mirror of
https://github.com/RYDE-WORK/ballistica.git
synced 2026-02-06 07:23:37 +08:00
work on async messaging
This commit is contained in:
parent
b46b1cbca8
commit
c8a6c3733d
6
.idea/dictionaries/ericf.xml
generated
6
.idea/dictionaries/ericf.xml
generated
@ -159,6 +159,7 @@
|
|||||||
<w>availmins</w>
|
<w>availmins</w>
|
||||||
<w>availplug</w>
|
<w>availplug</w>
|
||||||
<w>aval</w>
|
<w>aval</w>
|
||||||
|
<w>awaitable</w>
|
||||||
<w>axismotion</w>
|
<w>axismotion</w>
|
||||||
<w>bacfg</w>
|
<w>bacfg</w>
|
||||||
<w>backgrounded</w>
|
<w>backgrounded</w>
|
||||||
@ -303,6 +304,7 @@
|
|||||||
<w>capturetheflag</w>
|
<w>capturetheflag</w>
|
||||||
<w>carentity</w>
|
<w>carentity</w>
|
||||||
<w>cashregistersound</w>
|
<w>cashregistersound</w>
|
||||||
|
<w>cbgn</w>
|
||||||
<w>cbits</w>
|
<w>cbits</w>
|
||||||
<w>cbot</w>
|
<w>cbot</w>
|
||||||
<w>cbtn</w>
|
<w>cbtn</w>
|
||||||
@ -313,6 +315,7 @@
|
|||||||
<w>cdrk</w>
|
<w>cdrk</w>
|
||||||
<w>cdull</w>
|
<w>cdull</w>
|
||||||
<w>cdval</w>
|
<w>cdval</w>
|
||||||
|
<w>cend</w>
|
||||||
<w>centeuro</w>
|
<w>centeuro</w>
|
||||||
<w>centiseconds</w>
|
<w>centiseconds</w>
|
||||||
<w>cfconfig</w>
|
<w>cfconfig</w>
|
||||||
@ -1869,6 +1872,8 @@
|
|||||||
<w>rawpaths</w>
|
<w>rawpaths</w>
|
||||||
<w>rcade</w>
|
<w>rcade</w>
|
||||||
<w>rcfile</w>
|
<w>rcfile</w>
|
||||||
|
<w>rcva</w>
|
||||||
|
<w>rcvs</w>
|
||||||
<w>rdict</w>
|
<w>rdict</w>
|
||||||
<w>rdir</w>
|
<w>rdir</w>
|
||||||
<w>readline</w>
|
<w>readline</w>
|
||||||
@ -1997,6 +2002,7 @@
|
|||||||
<w>selwidget</w>
|
<w>selwidget</w>
|
||||||
<w>selwidgets</w>
|
<w>selwidgets</w>
|
||||||
<w>sendable</w>
|
<w>sendable</w>
|
||||||
|
<w>sendmethod</w>
|
||||||
<w>senze</w>
|
<w>senze</w>
|
||||||
<w>seqtype</w>
|
<w>seqtype</w>
|
||||||
<w>seqtypestr</w>
|
<w>seqtypestr</w>
|
||||||
|
|||||||
@ -69,6 +69,7 @@
|
|||||||
<w>availmins</w>
|
<w>availmins</w>
|
||||||
<w>avel</w>
|
<w>avel</w>
|
||||||
<w>avels</w>
|
<w>avels</w>
|
||||||
|
<w>awaitable</w>
|
||||||
<w>axismotion</w>
|
<w>axismotion</w>
|
||||||
<w>backgrounded</w>
|
<w>backgrounded</w>
|
||||||
<w>backgrounding</w>
|
<w>backgrounding</w>
|
||||||
@ -155,10 +156,12 @@
|
|||||||
<w>cancelbtn</w>
|
<w>cancelbtn</w>
|
||||||
<w>capitan</w>
|
<w>capitan</w>
|
||||||
<w>cargs</w>
|
<w>cargs</w>
|
||||||
|
<w>cbgn</w>
|
||||||
<w>cbtnoffs</w>
|
<w>cbtnoffs</w>
|
||||||
<w>ccdd</w>
|
<w>ccdd</w>
|
||||||
<w>ccontext</w>
|
<w>ccontext</w>
|
||||||
<w>ccylinder</w>
|
<w>ccylinder</w>
|
||||||
|
<w>cend</w>
|
||||||
<w>centiseconds</w>
|
<w>centiseconds</w>
|
||||||
<w>cfgdir</w>
|
<w>cfgdir</w>
|
||||||
<w>cfgpath</w>
|
<w>cfgpath</w>
|
||||||
@ -860,6 +863,8 @@
|
|||||||
<w>rasterizer</w>
|
<w>rasterizer</w>
|
||||||
<w>rawkey</w>
|
<w>rawkey</w>
|
||||||
<w>rcade</w>
|
<w>rcade</w>
|
||||||
|
<w>rcva</w>
|
||||||
|
<w>rcvs</w>
|
||||||
<w>reaaaly</w>
|
<w>reaaaly</w>
|
||||||
<w>readset</w>
|
<w>readset</w>
|
||||||
<w>realloc</w>
|
<w>realloc</w>
|
||||||
@ -922,6 +927,7 @@
|
|||||||
<w>selwidget</w>
|
<w>selwidget</w>
|
||||||
<w>selwidgets</w>
|
<w>selwidgets</w>
|
||||||
<w>sendable</w>
|
<w>sendable</w>
|
||||||
|
<w>sendmethod</w>
|
||||||
<w>seqlen</w>
|
<w>seqlen</w>
|
||||||
<w>seqtype</w>
|
<w>seqtype</w>
|
||||||
<w>seqtypestr</w>
|
<w>seqtypestr</w>
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
<!-- THIS FILE IS AUTO GENERATED; DO NOT EDIT BY HAND -->
|
<!-- THIS FILE IS AUTO GENERATED; DO NOT EDIT BY HAND -->
|
||||||
<h4><em>last updated on 2021-09-19 for Ballistica version 1.6.5 build 20393</em></h4>
|
<h4><em>last updated on 2021-09-21 for Ballistica version 1.6.5 build 20393</em></h4>
|
||||||
<p>This page documents the Python classes and functions in the 'ba' module,
|
<p>This page documents the Python classes and functions in the 'ba' module,
|
||||||
which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please <a href="mailto:support@froemling.net">let me know</a>. Happy modding!</p>
|
which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please <a href="mailto:support@froemling.net">let me know</a>. Happy modding!</p>
|
||||||
<hr>
|
<hr>
|
||||||
|
|||||||
@ -5,6 +5,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
from typing import TYPE_CHECKING, overload
|
from typing import TYPE_CHECKING, overload
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@ -17,55 +18,55 @@ from efro.message import (Message, Response, MessageProtocol, MessageSender,
|
|||||||
MessageReceiver)
|
MessageReceiver)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import List, Type, Any, Callable, Union, Optional
|
from typing import List, Type, Any, Callable, Union, Optional, Awaitable
|
||||||
|
|
||||||
|
|
||||||
@ioprepped
|
@ioprepped
|
||||||
@dataclass
|
@dataclass
|
||||||
class _TMessage1(Message):
|
class _TMsg1(Message):
|
||||||
"""Just testing."""
|
"""Just testing."""
|
||||||
ival: int
|
ival: int
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_response_types(cls) -> List[Type[Response]]:
|
def get_response_types(cls) -> List[Type[Response]]:
|
||||||
return [_TResponse1]
|
return [_TResp1]
|
||||||
|
|
||||||
|
|
||||||
@ioprepped
|
@ioprepped
|
||||||
@dataclass
|
@dataclass
|
||||||
class _TMessage2(Message):
|
class _TMsg2(Message):
|
||||||
"""Just testing."""
|
"""Just testing."""
|
||||||
sval: str
|
sval: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_response_types(cls) -> List[Type[Response]]:
|
def get_response_types(cls) -> List[Type[Response]]:
|
||||||
return [_TResponse1, _TResponse2]
|
return [_TResp1, _TResp2]
|
||||||
|
|
||||||
|
|
||||||
@ioprepped
|
@ioprepped
|
||||||
@dataclass
|
@dataclass
|
||||||
class _TMessage3(Message):
|
class _TMsg3(Message):
|
||||||
"""Just testing."""
|
"""Just testing."""
|
||||||
sval: str
|
sval: str
|
||||||
|
|
||||||
|
|
||||||
@ioprepped
|
@ioprepped
|
||||||
@dataclass
|
@dataclass
|
||||||
class _TResponse1(Response):
|
class _TResp1(Response):
|
||||||
"""Just testing."""
|
"""Just testing."""
|
||||||
bval: bool
|
bval: bool
|
||||||
|
|
||||||
|
|
||||||
@ioprepped
|
@ioprepped
|
||||||
@dataclass
|
@dataclass
|
||||||
class _TResponse2(Response):
|
class _TResp2(Response):
|
||||||
"""Just testing."""
|
"""Just testing."""
|
||||||
fval: float
|
fval: float
|
||||||
|
|
||||||
|
|
||||||
@ioprepped
|
@ioprepped
|
||||||
@dataclass
|
@dataclass
|
||||||
class _TResponse3(Message):
|
class _TResp3(Message):
|
||||||
"""Just testing."""
|
"""Just testing."""
|
||||||
fval: float
|
fval: float
|
||||||
|
|
||||||
@ -91,55 +92,73 @@ class _BoundTestMessageSender:
|
|||||||
self._sender = sender
|
self._sender = sender
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def send(self, message: _TMessage1) -> _TResponse1:
|
def send(self, message: _TMsg1) -> _TResp1:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def send(self, message: _TMessage2) -> Union[_TResponse1, _TResponse2]:
|
def send(self, message: _TMsg2) -> Union[_TResp1, _TResp2]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def send(self, message: _TMessage3) -> None:
|
def send(self, message: _TMsg3) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
def send(self, message: Message) -> Optional[Response]:
|
def send(self, message: Message) -> Optional[Response]:
|
||||||
"""Send a message."""
|
"""Send a message synchronously."""
|
||||||
return self._sender.send(self._obj, message)
|
return self._sender.send(self._obj, message)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def send_async(self, message: _TMsg1) -> _TResp1:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def send_async(self, message: _TMsg2) -> Union[_TResp1, _TResp2]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def send_async(self, message: _TMsg3) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def send_async(self, message: Message) -> Optional[Response]:
|
||||||
|
"""Send a message asynchronously."""
|
||||||
|
return await self._sender.send_async(self._obj, message)
|
||||||
|
|
||||||
|
|
||||||
# SEND_CODE_TEST_END
|
# SEND_CODE_TEST_END
|
||||||
# RCV_CODE_TEST_BEGIN
|
# RCVS_CODE_TEST_BEGIN
|
||||||
|
|
||||||
|
|
||||||
class _TestMessageReceiver(MessageReceiver):
|
class _TestSyncMessageReceiver(MessageReceiver):
|
||||||
"""Protocol-specific receiver."""
|
"""Protocol-specific synchronous receiver."""
|
||||||
|
|
||||||
|
is_async = False
|
||||||
|
|
||||||
def __get__(
|
def __get__(
|
||||||
self,
|
self,
|
||||||
obj: Any,
|
obj: Any,
|
||||||
type_in: Any = None,
|
type_in: Any = None,
|
||||||
) -> _BoundTestMessageReceiver:
|
) -> _BoundTestSyncMessageReceiver:
|
||||||
return _BoundTestMessageReceiver(obj, self)
|
return _BoundTestSyncMessageReceiver(obj, self)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def handler(
|
def handler(
|
||||||
self,
|
self,
|
||||||
call: Callable[[Any, _TMessage1], _TResponse1],
|
call: Callable[[Any, _TMsg1], _TResp1],
|
||||||
) -> Callable[[Any, _TMessage1], _TResponse1]:
|
) -> Callable[[Any, _TMsg1], _TResp1]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def handler(
|
def handler(
|
||||||
self,
|
self,
|
||||||
call: Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]],
|
call: Callable[[Any, _TMsg2], Union[_TResp1, _TResp2]],
|
||||||
) -> Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]]:
|
) -> Callable[[Any, _TMsg2], Union[_TResp1, _TResp2]]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def handler(
|
def handler(
|
||||||
self,
|
self,
|
||||||
call: Callable[[Any, _TMessage3], None],
|
call: Callable[[Any, _TMsg3], None],
|
||||||
) -> Callable[[Any, _TMessage3], None]:
|
) -> Callable[[Any, _TMsg3], None]:
|
||||||
...
|
...
|
||||||
|
|
||||||
def handler(self, call: Callable) -> Callable:
|
def handler(self, call: Callable) -> Callable:
|
||||||
@ -148,34 +167,104 @@ class _TestMessageReceiver(MessageReceiver):
|
|||||||
return call
|
return call
|
||||||
|
|
||||||
|
|
||||||
class _BoundTestMessageReceiver:
|
class _BoundTestSyncMessageReceiver:
|
||||||
"""Protocol-specific bound receiver."""
|
"""Protocol-specific bound receiver."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
obj: Any,
|
obj: Any,
|
||||||
receiver: _TestMessageReceiver,
|
receiver: _TestSyncMessageReceiver,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert obj is not None
|
assert obj is not None
|
||||||
self._obj = obj
|
self._obj = obj
|
||||||
self._receiver = receiver
|
self._receiver = receiver
|
||||||
|
|
||||||
def handle_raw_message(self, message: bytes) -> bytes:
|
def handle_raw_message(self, message: bytes) -> bytes:
|
||||||
"""Handle a raw incoming message."""
|
"""Handle a raw incoming synchronous message."""
|
||||||
return self._receiver.handle_raw_message(self._obj, message)
|
return self._receiver.handle_raw_message(self._obj, message)
|
||||||
|
|
||||||
|
async def handle_raw_message_async(self, message: bytes) -> bytes:
|
||||||
|
"""Handle a raw incoming asynchronous message."""
|
||||||
|
return await self._receiver.handle_raw_message_async(
|
||||||
|
self._obj, message)
|
||||||
|
|
||||||
# RCV_CODE_TEST_END
|
|
||||||
|
# RCVS_CODE_TEST_END
|
||||||
|
# RCVA_CODE_TEST_BEGIN
|
||||||
|
|
||||||
|
|
||||||
|
class _TestAsyncMessageReceiver(MessageReceiver):
|
||||||
|
"""Protocol-specific asynchronous receiver."""
|
||||||
|
|
||||||
|
is_async = True
|
||||||
|
|
||||||
|
def __get__(
|
||||||
|
self,
|
||||||
|
obj: Any,
|
||||||
|
type_in: Any = None,
|
||||||
|
) -> _BoundTestAsyncMessageReceiver:
|
||||||
|
return _BoundTestAsyncMessageReceiver(obj, self)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def handler(
|
||||||
|
self,
|
||||||
|
call: Callable[[Any, _TMsg1], Awaitable[_TResp1]],
|
||||||
|
) -> Callable[[Any, _TMsg1], Awaitable[_TResp1]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def handler(
|
||||||
|
self,
|
||||||
|
call: Callable[[Any, _TMsg2], Awaitable[Union[_TResp1, _TResp2]]],
|
||||||
|
) -> Callable[[Any, _TMsg2], Awaitable[Union[_TResp1, _TResp2]]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def handler(
|
||||||
|
self,
|
||||||
|
call: Callable[[Any, _TMsg3], Awaitable[None]],
|
||||||
|
) -> Callable[[Any, _TMsg3], Awaitable[None]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def handler(self, call: Callable) -> Callable:
|
||||||
|
"""Decorator to register message handlers."""
|
||||||
|
self.register_handler(call)
|
||||||
|
return call
|
||||||
|
|
||||||
|
|
||||||
|
class _BoundTestAsyncMessageReceiver:
|
||||||
|
"""Protocol-specific bound receiver."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
obj: Any,
|
||||||
|
receiver: _TestAsyncMessageReceiver,
|
||||||
|
) -> None:
|
||||||
|
assert obj is not None
|
||||||
|
self._obj = obj
|
||||||
|
self._receiver = receiver
|
||||||
|
|
||||||
|
def handle_raw_message(self, message: bytes) -> bytes:
|
||||||
|
"""Handle a raw incoming synchronous message."""
|
||||||
|
return self._receiver.handle_raw_message(self._obj, message)
|
||||||
|
|
||||||
|
async def handle_raw_message_async(self, message: bytes) -> bytes:
|
||||||
|
"""Handle a raw incoming asynchronous message."""
|
||||||
|
return await self._receiver.handle_raw_message_async(
|
||||||
|
self._obj, message)
|
||||||
|
|
||||||
|
|
||||||
|
# RCVA_CODE_TEST_END
|
||||||
|
|
||||||
TEST_PROTOCOL = MessageProtocol(
|
TEST_PROTOCOL = MessageProtocol(
|
||||||
message_types={
|
message_types={
|
||||||
0: _TMessage1,
|
0: _TMsg1,
|
||||||
1: _TMessage2,
|
1: _TMsg2,
|
||||||
2: _TMessage3,
|
2: _TMsg3,
|
||||||
},
|
},
|
||||||
response_types={
|
response_types={
|
||||||
0: _TResponse1,
|
0: _TResp1,
|
||||||
1: _TResponse2,
|
1: _TResp2,
|
||||||
},
|
},
|
||||||
trusted_sender=True,
|
trusted_sender=True,
|
||||||
log_remote_exceptions=False,
|
log_remote_exceptions=False,
|
||||||
@ -185,20 +274,21 @@ TEST_PROTOCOL = MessageProtocol(
|
|||||||
def test_protocol_creation() -> None:
|
def test_protocol_creation() -> None:
|
||||||
"""Test protocol creation."""
|
"""Test protocol creation."""
|
||||||
|
|
||||||
# This should fail because _TMessage1 can return _TResponse1 which
|
# This should fail because _TMsg1 can return _TResp1 which
|
||||||
# is not given an id here.
|
# is not given an id here.
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
_protocol = MessageProtocol(message_types={0: _TMessage1},
|
_protocol = MessageProtocol(message_types={0: _TMsg1},
|
||||||
response_types={0: _TResponse2})
|
response_types={0: _TResp2})
|
||||||
|
|
||||||
# Now it should work.
|
# Now it should work.
|
||||||
_protocol = MessageProtocol(message_types={0: _TMessage1},
|
_protocol = MessageProtocol(message_types={0: _TMsg1},
|
||||||
response_types={0: _TResponse1})
|
response_types={0: _TResp1})
|
||||||
|
|
||||||
|
|
||||||
def test_sender_module_embedded() -> None:
|
def test_sender_module_embedded() -> None:
|
||||||
"""Test generation of protocol-specific sender modules for typing/etc."""
|
"""Test generation of protocol-specific sender modules for typing/etc."""
|
||||||
smod = TEST_PROTOCOL.create_sender_module('Test', private=True)
|
smod = TEST_PROTOCOL.create_sender_module('TestMessageSender',
|
||||||
|
private=True)
|
||||||
|
|
||||||
# Clip everything up to our first class declaration.
|
# Clip everything up to our first class declaration.
|
||||||
lines = smod.splitlines()
|
lines = smod.splitlines()
|
||||||
@ -218,13 +308,15 @@ def test_sender_module_embedded() -> None:
|
|||||||
' See test stdout for new code.')
|
' See test stdout for new code.')
|
||||||
|
|
||||||
|
|
||||||
def test_receiver_module_embedded() -> None:
|
def test_receiver_module_sync_embedded() -> None:
|
||||||
"""Test generation of protocol-specific sender modules for typing/etc."""
|
"""Test generation of protocol-specific sender modules for typing/etc."""
|
||||||
smod = TEST_PROTOCOL.create_receiver_module('Test', private=True)
|
smod = TEST_PROTOCOL.create_receiver_module('TestSyncMessageReceiver',
|
||||||
|
is_async=False,
|
||||||
|
private=True)
|
||||||
|
|
||||||
# Clip everything up to our first class declaration.
|
# Clip everything up to our first class declaration.
|
||||||
lines = smod.splitlines()
|
lines = smod.splitlines()
|
||||||
classline = lines.index('class _TestMessageReceiver(MessageReceiver):')
|
classline = lines.index('class _TestSyncMessageReceiver(MessageReceiver):')
|
||||||
clipped = '\n'.join(lines[classline:])
|
clipped = '\n'.join(lines[classline:])
|
||||||
|
|
||||||
# This snippet should match what we've got embedded above;
|
# This snippet should match what we've got embedded above;
|
||||||
@ -232,12 +324,39 @@ def test_receiver_module_embedded() -> None:
|
|||||||
with open(__file__, encoding='utf-8') as infile:
|
with open(__file__, encoding='utf-8') as infile:
|
||||||
ourcode = infile.read()
|
ourcode = infile.read()
|
||||||
|
|
||||||
emb = f'# RCV_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCV_CODE_TEST_END\n'
|
emb = f'# RCVS_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCVS_CODE_TEST_END\n'
|
||||||
if emb not in ourcode:
|
if emb not in ourcode:
|
||||||
print(f'EXPECTED EMBEDDED CODE:\n{emb}')
|
print(f'EXPECTED SYNC RECEIVER EMBEDDED CODE:\n{emb}')
|
||||||
raise RuntimeError('Generated sender module does not match embedded;'
|
raise RuntimeError(
|
||||||
' test code needs to be updated.'
|
'Generated sync receiver module does not match embedded;'
|
||||||
' See test stdout for new code.')
|
' test code needs to be updated.'
|
||||||
|
' See test stdout for new code.')
|
||||||
|
|
||||||
|
|
||||||
|
def test_receiver_module_async_embedded() -> None:
|
||||||
|
"""Test generation of protocol-specific sender modules for typing/etc."""
|
||||||
|
smod = TEST_PROTOCOL.create_receiver_module('TestAsyncMessageReceiver',
|
||||||
|
is_async=True,
|
||||||
|
private=True)
|
||||||
|
|
||||||
|
# Clip everything up to our first class declaration.
|
||||||
|
lines = smod.splitlines()
|
||||||
|
classline = lines.index(
|
||||||
|
'class _TestAsyncMessageReceiver(MessageReceiver):')
|
||||||
|
clipped = '\n'.join(lines[classline:])
|
||||||
|
|
||||||
|
# This snippet should match what we've got embedded above;
|
||||||
|
# If not then we need to update our embedded version.
|
||||||
|
with open(__file__, encoding='utf-8') as infile:
|
||||||
|
ourcode = infile.read()
|
||||||
|
|
||||||
|
emb = f'# RCVA_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCVA_CODE_TEST_END\n'
|
||||||
|
if emb not in ourcode:
|
||||||
|
print(f'EXPECTED ASYNC RECEIVER EMBEDDED CODE:\n{emb}')
|
||||||
|
raise RuntimeError(
|
||||||
|
'Generated async receiver module does not match embedded;'
|
||||||
|
' test code needs to be updated.'
|
||||||
|
' See test stdout for new code.')
|
||||||
|
|
||||||
|
|
||||||
def test_receiver_creation() -> None:
|
def test_receiver_creation() -> None:
|
||||||
@ -251,13 +370,13 @@ def test_receiver_creation() -> None:
|
|||||||
class _TestClassR:
|
class _TestClassR:
|
||||||
"""Test class incorporating receive functionality."""
|
"""Test class incorporating receive functionality."""
|
||||||
|
|
||||||
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
receiver = _TestSyncMessageReceiver(TEST_PROTOCOL)
|
||||||
|
|
||||||
@receiver.handler
|
@receiver.handler
|
||||||
def handle_test_message_2(self, msg: _TMessage2) -> _TResponse2:
|
def handle_test_message_2(self, msg: _TMsg2) -> _TResp2:
|
||||||
"""Test."""
|
"""Test."""
|
||||||
del msg # Unused
|
del msg # Unused
|
||||||
return _TResponse2(fval=1.2)
|
return _TResp2(fval=1.2)
|
||||||
|
|
||||||
# Validation should fail because not all message types in the
|
# Validation should fail because not all message types in the
|
||||||
# protocol are handled.
|
# protocol are handled.
|
||||||
@ -266,13 +385,13 @@ def test_receiver_creation() -> None:
|
|||||||
class _TestClassR2:
|
class _TestClassR2:
|
||||||
"""Test class incorporating receive functionality."""
|
"""Test class incorporating receive functionality."""
|
||||||
|
|
||||||
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
receiver = _TestSyncMessageReceiver(TEST_PROTOCOL)
|
||||||
|
|
||||||
# Checks that we've added handlers for all message types, etc.
|
# Checks that we've added handlers for all message types, etc.
|
||||||
receiver.validate()
|
receiver.validate()
|
||||||
|
|
||||||
|
|
||||||
def test_synchronous_messaging() -> None:
|
def test_full_pipeline() -> None:
|
||||||
"""Test the full pipeline."""
|
"""Test the full pipeline."""
|
||||||
|
|
||||||
# Define a class that can send messages and one that can receive them.
|
# Define a class that can send messages and one that can receive them.
|
||||||
@ -281,66 +400,119 @@ def test_synchronous_messaging() -> None:
|
|||||||
|
|
||||||
msg = _TestMessageSender(TEST_PROTOCOL)
|
msg = _TestMessageSender(TEST_PROTOCOL)
|
||||||
|
|
||||||
def __init__(self, target: TestClassR) -> None:
|
def __init__(self, target: Union[TestClassRSync,
|
||||||
|
TestClassRAsync]) -> None:
|
||||||
self._target = target
|
self._target = target
|
||||||
|
|
||||||
@msg.send_raw_handler
|
@msg.send_method
|
||||||
def _send_raw_message(self, data: bytes) -> bytes:
|
def _send_raw_message(self, data: bytes) -> bytes:
|
||||||
"""Test."""
|
"""Handle synchronous sending of raw message data."""
|
||||||
|
# Just talk directly to the receiver for this example.
|
||||||
|
# (currently only support synchronous receivers)
|
||||||
|
assert isinstance(self._target, TestClassRSync)
|
||||||
return self._target.receiver.handle_raw_message(data)
|
return self._target.receiver.handle_raw_message(data)
|
||||||
|
|
||||||
class TestClassR:
|
@msg.send_async_method
|
||||||
"""Test class incorporating receive functionality."""
|
async def _send_raw_message_async(self, data: bytes) -> bytes:
|
||||||
|
"""Handle asynchronous sending of raw message data."""
|
||||||
|
# Just talk directly to the receiver for this example.
|
||||||
|
# (we can do sync or async receivers)
|
||||||
|
if isinstance(self._target, TestClassRSync):
|
||||||
|
return self._target.receiver.handle_raw_message(data)
|
||||||
|
return await self._target.receiver.handle_raw_message_async(data)
|
||||||
|
|
||||||
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
class TestClassRSync:
|
||||||
|
"""Test class incorporating synchronous receive functionality."""
|
||||||
|
|
||||||
|
receiver = _TestSyncMessageReceiver(TEST_PROTOCOL)
|
||||||
|
|
||||||
@receiver.handler
|
@receiver.handler
|
||||||
def handle_test_message_1(self, msg: _TMessage1) -> _TResponse1:
|
def handle_test_message_1(self, msg: _TMsg1) -> _TResp1:
|
||||||
"""Test."""
|
"""Test."""
|
||||||
if msg.ival == 1:
|
if msg.ival == 1:
|
||||||
raise CleanError('Testing Clean Error')
|
raise CleanError('Testing Clean Error')
|
||||||
if msg.ival == 2:
|
if msg.ival == 2:
|
||||||
raise RuntimeError('Testing Runtime Error')
|
raise RuntimeError('Testing Runtime Error')
|
||||||
return _TResponse1(bval=True)
|
return _TResp1(bval=True)
|
||||||
|
|
||||||
@receiver.handler
|
@receiver.handler
|
||||||
def handle_test_message_2(
|
def handle_test_message_2(self,
|
||||||
self, msg: _TMessage2) -> Union[_TResponse1, _TResponse2]:
|
msg: _TMsg2) -> Union[_TResp1, _TResp2]:
|
||||||
"""Test."""
|
"""Test."""
|
||||||
del msg # Unused
|
del msg # Unused
|
||||||
return _TResponse2(fval=1.2)
|
return _TResp2(fval=1.2)
|
||||||
|
|
||||||
@receiver.handler
|
@receiver.handler
|
||||||
def handle_test_message_3(self, msg: _TMessage3) -> None:
|
def handle_test_message_3(self, msg: _TMsg3) -> None:
|
||||||
"""Test."""
|
"""Test."""
|
||||||
del msg # Unused
|
del msg # Unused
|
||||||
|
|
||||||
receiver.validate()
|
receiver.validate()
|
||||||
|
|
||||||
obj_r = TestClassR()
|
class TestClassRAsync:
|
||||||
obj = TestClassS(target=obj_r)
|
"""Test class incorporating asynchronous receive functionality."""
|
||||||
|
|
||||||
response = obj.msg.send(_TMessage1(ival=0))
|
receiver = _TestAsyncMessageReceiver(TEST_PROTOCOL)
|
||||||
assert isinstance(response, _TResponse1)
|
|
||||||
|
|
||||||
response2 = obj.msg.send(_TMessage2(sval='rah'))
|
@receiver.handler
|
||||||
assert isinstance(response2, (_TResponse1, _TResponse2))
|
async def handle_test_message_1(self, msg: _TMsg1) -> _TResp1:
|
||||||
|
"""Test."""
|
||||||
|
if msg.ival == 1:
|
||||||
|
raise CleanError('Testing Clean Error')
|
||||||
|
if msg.ival == 2:
|
||||||
|
raise RuntimeError('Testing Runtime Error')
|
||||||
|
return _TResp1(bval=True)
|
||||||
|
|
||||||
response3 = obj.msg.send(_TMessage3(sval='rah'))
|
@receiver.handler
|
||||||
assert response3 is None
|
async def handle_test_message_2(
|
||||||
|
self, msg: _TMsg2) -> Union[_TResp1, _TResp2]:
|
||||||
|
"""Test."""
|
||||||
|
del msg # Unused
|
||||||
|
return _TResp2(fval=1.2)
|
||||||
|
|
||||||
# Make sure static typing lines up too.
|
@receiver.handler
|
||||||
|
async def handle_test_message_3(self, msg: _TMsg3) -> None:
|
||||||
|
"""Test."""
|
||||||
|
del msg # Unused
|
||||||
|
|
||||||
|
receiver.validate()
|
||||||
|
|
||||||
|
obj_r_sync = TestClassRSync()
|
||||||
|
obj_r_async = TestClassRAsync()
|
||||||
|
obj = TestClassS(target=obj_r_sync)
|
||||||
|
obj2 = TestClassS(target=obj_r_async)
|
||||||
|
|
||||||
|
# Test sends (of sync and async varieties).
|
||||||
|
response1 = obj.msg.send(_TMsg1(ival=0))
|
||||||
|
response2 = obj.msg.send(_TMsg2(sval='rah'))
|
||||||
|
response3 = obj.msg.send(_TMsg3(sval='rah'))
|
||||||
|
response4 = asyncio.run(obj.msg.send_async(_TMsg1(ival=0)))
|
||||||
|
|
||||||
|
# Make sure static typing lines up with what we expect.
|
||||||
if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
|
if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
|
||||||
assert static_type_equals(response, _TResponse1)
|
assert static_type_equals(response1, _TResp1)
|
||||||
assert static_type_equals(response3, None)
|
assert static_type_equals(response3, None)
|
||||||
|
|
||||||
|
assert isinstance(response1, _TResp1)
|
||||||
|
assert isinstance(response2, (_TResp1, _TResp2))
|
||||||
|
assert response3 is None
|
||||||
|
assert isinstance(response4, _TResp1)
|
||||||
|
|
||||||
# Remote CleanErrors should come across locally as the same.
|
# Remote CleanErrors should come across locally as the same.
|
||||||
try:
|
try:
|
||||||
_response4 = obj.msg.send(_TMessage1(ival=1))
|
_response5 = obj.msg.send(_TMsg1(ival=1))
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
assert isinstance(exc, CleanError)
|
assert isinstance(exc, CleanError)
|
||||||
assert str(exc) == 'Testing Clean Error'
|
assert str(exc) == 'Testing Clean Error'
|
||||||
|
|
||||||
# Other remote errors should result in RemoteError.
|
# Other remote errors should result in RemoteError.
|
||||||
with pytest.raises(RemoteError):
|
with pytest.raises(RemoteError):
|
||||||
_response4 = obj.msg.send(_TMessage1(ival=2))
|
_response5 = obj.msg.send(_TMsg1(ival=2))
|
||||||
|
|
||||||
|
# Now test sends to async handlers.
|
||||||
|
response6 = asyncio.run(obj2.msg.send_async(_TMsg1(ival=0)))
|
||||||
|
assert isinstance(response6, _TResp1)
|
||||||
|
|
||||||
|
# Make sure static typing lines up with what we expect.
|
||||||
|
if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
|
||||||
|
assert static_type_equals(response6, _TResp1)
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs,
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set,
|
from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set,
|
||||||
Sequence, Union)
|
Sequence, Union, Awaitable)
|
||||||
from efro.error import CommunicationError
|
from efro.error import CommunicationError
|
||||||
|
|
||||||
TM = TypeVar('TM', bound='MessageSender')
|
TM = TypeVar('TM', bound='MessageSender')
|
||||||
@ -300,7 +300,7 @@ class MessageProtocol:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
def create_sender_module(self,
|
def create_sender_module(self,
|
||||||
classname: str,
|
basename: str,
|
||||||
private: bool = False) -> str:
|
private: bool = False) -> str:
|
||||||
""""Create a Python module defining a MessageSender subclass.
|
""""Create a Python module defining a MessageSender subclass.
|
||||||
|
|
||||||
@ -308,29 +308,35 @@ class MessageProtocol:
|
|||||||
for the varieties of send calls for message/response types defined
|
for the varieties of send calls for message/response types defined
|
||||||
in the protocol.
|
in the protocol.
|
||||||
|
|
||||||
|
Class names are based on basename; a basename 'FooSender' will
|
||||||
|
result in classes FooSender and BoundFooSender.
|
||||||
|
|
||||||
|
If 'private' is True, class-names will be prefixed with an '_'.
|
||||||
|
|
||||||
Note that line lengths are not clipped, so output may need to be
|
Note that line lengths are not clipped, so output may need to be
|
||||||
run through a formatter to prevent lint warnings about excessive
|
run through a formatter to prevent lint warnings about excessive
|
||||||
line lengths.
|
line lengths.
|
||||||
"""
|
"""
|
||||||
|
# pylint: disable=too-many-locals
|
||||||
|
|
||||||
ppre = '_' if private else ''
|
ppre = '_' if private else ''
|
||||||
out = self._get_module_header('sender')
|
out = self._get_module_header('sender')
|
||||||
out += (f'class {ppre}{classname}MessageSender(MessageSender):\n'
|
out += (f'class {ppre}{basename}(MessageSender):\n'
|
||||||
f' """Protocol-specific sender."""\n'
|
f' """Protocol-specific sender."""\n'
|
||||||
f'\n'
|
f'\n'
|
||||||
f' def __get__(self,\n'
|
f' def __get__(self,\n'
|
||||||
f' obj: Any,\n'
|
f' obj: Any,\n'
|
||||||
f' type_in: Any = None)'
|
f' type_in: Any = None)'
|
||||||
f' -> {ppre}Bound{classname}MessageSender:\n'
|
f' -> {ppre}Bound{basename}:\n'
|
||||||
f' return {ppre}Bound{classname}MessageSender'
|
f' return {ppre}Bound{basename}'
|
||||||
f'(obj, self)\n'
|
f'(obj, self)\n'
|
||||||
f'\n'
|
f'\n'
|
||||||
f'\n'
|
f'\n'
|
||||||
f'class {ppre}Bound{classname}MessageSender:\n'
|
f'class {ppre}Bound{basename}:\n'
|
||||||
f' """Protocol-specific bound sender."""\n'
|
f' """Protocol-specific bound sender."""\n'
|
||||||
f'\n'
|
f'\n'
|
||||||
f' def __init__(self, obj: Any,'
|
f' def __init__(self, obj: Any,'
|
||||||
f' sender: {ppre}{classname}MessageSender) -> None:\n'
|
f' sender: {ppre}{basename}) -> None:\n'
|
||||||
f' assert obj is not None\n'
|
f' assert obj is not None\n'
|
||||||
f' self._obj = obj\n'
|
f' self._obj = obj\n'
|
||||||
f' self._sender = sender\n')
|
f' self._sender = sender\n')
|
||||||
@ -352,29 +358,37 @@ class MessageProtocol:
|
|||||||
return 'None' if rtype is EmptyResponse else rtype.__name__
|
return 'None' if rtype is EmptyResponse else rtype.__name__
|
||||||
|
|
||||||
if len(msgtypes) > 1:
|
if len(msgtypes) > 1:
|
||||||
for msgtype in msgtypes:
|
for async_pass in False, True:
|
||||||
msgtypevar = msgtype.__name__
|
pfx = 'async ' if async_pass else ''
|
||||||
rtypes = msgtype.get_response_types()
|
sfx = '_async' if async_pass else ''
|
||||||
if len(rtypes) > 1:
|
awt = 'await ' if async_pass else ''
|
||||||
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
|
how = 'asynchronously' if async_pass else 'synchronously'
|
||||||
rtypevar = f'Union[{tps}]'
|
for msgtype in msgtypes:
|
||||||
else:
|
msgtypevar = msgtype.__name__
|
||||||
rtypevar = _filt_tp_name(rtypes[0])
|
rtypes = msgtype.get_response_types()
|
||||||
|
if len(rtypes) > 1:
|
||||||
|
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
|
||||||
|
rtypevar = f'Union[{tps}]'
|
||||||
|
else:
|
||||||
|
rtypevar = _filt_tp_name(rtypes[0])
|
||||||
|
out += (f'\n'
|
||||||
|
f' @overload\n'
|
||||||
|
f' {pfx}def send{sfx}(self,'
|
||||||
|
f' message: {msgtypevar})'
|
||||||
|
f' -> {rtypevar}:\n'
|
||||||
|
f' ...\n')
|
||||||
out += (f'\n'
|
out += (f'\n'
|
||||||
f' @overload\n'
|
f' {pfx}def send{sfx}(self, message: Message)'
|
||||||
f' def send(self, message: {msgtypevar})'
|
f' -> Optional[Response]:\n'
|
||||||
f' -> {rtypevar}:\n'
|
f' """Send a message {how}."""\n'
|
||||||
f' ...\n')
|
f' return {awt}self._sender.'
|
||||||
out += ('\n'
|
f'send{sfx}(self._obj, message)\n')
|
||||||
' def send(self, message: Message)'
|
|
||||||
' -> Optional[Response]:\n'
|
|
||||||
' """Send a message."""\n'
|
|
||||||
' return self._sender.send(self._obj, message)\n')
|
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def create_receiver_module(self,
|
def create_receiver_module(self,
|
||||||
classname: str,
|
basename: str,
|
||||||
|
is_async: bool,
|
||||||
private: bool = False) -> str:
|
private: bool = False) -> str:
|
||||||
""""Create a Python module defining a MessageReceiver subclass.
|
""""Create a Python module defining a MessageReceiver subclass.
|
||||||
|
|
||||||
@ -382,21 +396,33 @@ class MessageProtocol:
|
|||||||
for the register method for message/response types defined in
|
for the register method for message/response types defined in
|
||||||
the protocol.
|
the protocol.
|
||||||
|
|
||||||
|
Class names are based on basename; a basename 'FooReceiver' will
|
||||||
|
result in FooReceiver and BoundFooReceiver.
|
||||||
|
|
||||||
|
If 'is_async' is True, handle_raw_message() will be an async method
|
||||||
|
and the @handler decorator will expect async methods.
|
||||||
|
|
||||||
|
If 'private' is True, class-names will be prefixed with an '_'.
|
||||||
|
|
||||||
Note that line lengths are not clipped, so output may need to be
|
Note that line lengths are not clipped, so output may need to be
|
||||||
run through a formatter to prevent lint warnings about excessive
|
run through a formatter to prevent lint warnings about excessive
|
||||||
line lengths.
|
line lengths.
|
||||||
"""
|
"""
|
||||||
|
# pylint: disable=too-many-locals
|
||||||
|
desc = 'asynchronous' if is_async else 'synchronous'
|
||||||
ppre = '_' if private else ''
|
ppre = '_' if private else ''
|
||||||
out = self._get_module_header('receiver')
|
out = self._get_module_header('receiver')
|
||||||
out += (f'class {ppre}{classname}MessageReceiver(MessageReceiver):\n'
|
out += (f'class {ppre}{basename}(MessageReceiver):\n'
|
||||||
f' """Protocol-specific receiver."""\n'
|
f' """Protocol-specific {desc} receiver."""\n'
|
||||||
|
f'\n'
|
||||||
|
f' is_async = {is_async}\n'
|
||||||
f'\n'
|
f'\n'
|
||||||
f' def __get__(\n'
|
f' def __get__(\n'
|
||||||
f' self,\n'
|
f' self,\n'
|
||||||
f' obj: Any,\n'
|
f' obj: Any,\n'
|
||||||
f' type_in: Any = None,\n'
|
f' type_in: Any = None,\n'
|
||||||
f' ) -> {ppre}Bound{classname}MessageReceiver:\n'
|
f' ) -> {ppre}Bound{basename}:\n'
|
||||||
f' return {ppre}Bound{classname}MessageReceiver('
|
f' return {ppre}Bound{basename}('
|
||||||
f'obj, self)\n')
|
f'obj, self)\n')
|
||||||
|
|
||||||
# Define handler() overloads for all registered message types.
|
# Define handler() overloads for all registered message types.
|
||||||
@ -416,6 +442,8 @@ class MessageProtocol:
|
|||||||
return 'None' if rtype is EmptyResponse else rtype.__name__
|
return 'None' if rtype is EmptyResponse else rtype.__name__
|
||||||
|
|
||||||
if len(msgtypes) > 1:
|
if len(msgtypes) > 1:
|
||||||
|
cbgn = 'Awaitable[' if is_async else ''
|
||||||
|
cend = ']' if is_async else ''
|
||||||
for msgtype in msgtypes:
|
for msgtype in msgtypes:
|
||||||
msgtypevar = msgtype.__name__
|
msgtypevar = msgtype.__name__
|
||||||
rtypes = msgtype.get_response_types()
|
rtypes = msgtype.get_response_types()
|
||||||
@ -424,6 +452,7 @@ class MessageProtocol:
|
|||||||
rtypevar = f'Union[{tps}]'
|
rtypevar = f'Union[{tps}]'
|
||||||
else:
|
else:
|
||||||
rtypevar = _filt_tp_name(rtypes[0])
|
rtypevar = _filt_tp_name(rtypes[0])
|
||||||
|
rtypevar = f'{cbgn}{rtypevar}{cend}'
|
||||||
out += (
|
out += (
|
||||||
f'\n'
|
f'\n'
|
||||||
f' @overload\n'
|
f' @overload\n'
|
||||||
@ -441,22 +470,29 @@ class MessageProtocol:
|
|||||||
|
|
||||||
out += (f'\n'
|
out += (f'\n'
|
||||||
f'\n'
|
f'\n'
|
||||||
f'class {ppre}Bound{classname}MessageReceiver:\n'
|
f'class {ppre}Bound{basename}:\n'
|
||||||
f' """Protocol-specific bound receiver."""\n'
|
f' """Protocol-specific bound receiver."""\n'
|
||||||
f'\n'
|
f'\n'
|
||||||
f' def __init__(\n'
|
f' def __init__(\n'
|
||||||
f' self,\n'
|
f' self,\n'
|
||||||
f' obj: Any,\n'
|
f' obj: Any,\n'
|
||||||
f' receiver: {ppre}{classname}MessageReceiver,\n'
|
f' receiver: {ppre}{basename},\n'
|
||||||
f' ) -> None:\n'
|
f' ) -> None:\n'
|
||||||
f' assert obj is not None\n'
|
f' assert obj is not None\n'
|
||||||
f' self._obj = obj\n'
|
f' self._obj = obj\n'
|
||||||
f' self._receiver = receiver\n'
|
f' self._receiver = receiver\n'
|
||||||
f'\n'
|
f'\n'
|
||||||
f' def handle_raw_message(self, message: bytes) -> bytes:\n'
|
f' def handle_raw_message(self, message: bytes) -> bytes:\n'
|
||||||
f' """Handle a raw incoming message."""\n'
|
f' """Handle a raw incoming synchronous message."""\n'
|
||||||
f' return self._receiver.handle_raw_message'
|
f' return self._receiver.handle_raw_message'
|
||||||
f'(self._obj, message)\n')
|
f'(self._obj, message)\n'
|
||||||
|
f'\n'
|
||||||
|
f' async def handle_raw_message_async(self, message: bytes)'
|
||||||
|
f' -> bytes:\n'
|
||||||
|
f' """Handle a raw incoming asynchronous message."""\n'
|
||||||
|
f' return await'
|
||||||
|
f' self._receiver.handle_raw_message_async(\n'
|
||||||
|
f' self._obj, message)\n')
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -471,7 +507,7 @@ class MessageSender:
|
|||||||
class MyClass:
|
class MyClass:
|
||||||
msg = MyMessageSender(some_protocol)
|
msg = MyMessageSender(some_protocol)
|
||||||
|
|
||||||
@msg.send_raw_handler
|
@msg.sendmethod
|
||||||
def send_raw_message(self, message: bytes) -> bytes:
|
def send_raw_message(self, message: bytes) -> bytes:
|
||||||
# Actually send the message here.
|
# Actually send the message here.
|
||||||
|
|
||||||
@ -485,8 +521,10 @@ class MessageSender:
|
|||||||
self._protocol = protocol
|
self._protocol = protocol
|
||||||
self._send_raw_message_call: Optional[Callable[[Any, bytes],
|
self._send_raw_message_call: Optional[Callable[[Any, bytes],
|
||||||
bytes]] = None
|
bytes]] = None
|
||||||
|
self._send_async_raw_message_call: Optional[Callable[
|
||||||
|
[Any, bytes], Awaitable[bytes]]] = None
|
||||||
|
|
||||||
def send_raw_handler(
|
def send_method(
|
||||||
self, call: Callable[[Any, bytes],
|
self, call: Callable[[Any, bytes],
|
||||||
bytes]) -> Callable[[Any, bytes], bytes]:
|
bytes]) -> Callable[[Any, bytes], bytes]:
|
||||||
"""Function decorator for setting raw send method."""
|
"""Function decorator for setting raw send method."""
|
||||||
@ -494,6 +532,14 @@ class MessageSender:
|
|||||||
self._send_raw_message_call = call
|
self._send_raw_message_call = call
|
||||||
return call
|
return call
|
||||||
|
|
||||||
|
def send_async_method(
|
||||||
|
self, call: Callable[[Any, bytes], Awaitable[bytes]]
|
||||||
|
) -> Callable[[Any, bytes], Awaitable[bytes]]:
|
||||||
|
"""Function decorator for setting raw send-async method."""
|
||||||
|
assert self._send_async_raw_message_call is None
|
||||||
|
self._send_async_raw_message_call = call
|
||||||
|
return call
|
||||||
|
|
||||||
def send(self, bound_obj: Any, message: Message) -> Optional[Response]:
|
def send(self, bound_obj: Any, message: Message) -> Optional[Response]:
|
||||||
"""Send a message and receive a response.
|
"""Send a message and receive a response.
|
||||||
|
|
||||||
@ -510,21 +556,24 @@ class MessageSender:
|
|||||||
or type(response) in type(message).get_response_types())
|
or type(response) in type(message).get_response_types())
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def send_bg(self, bound_obj: Any, message: Message) -> Message:
|
async def send_async(self, bound_obj: Any,
|
||||||
"""Send a message asynchronously and receive a future.
|
message: Message) -> Optional[Response]:
|
||||||
|
|
||||||
The message will be encoded for transport and passed to
|
|
||||||
dispatch_raw_message from a background thread.
|
|
||||||
"""
|
|
||||||
raise RuntimeError('Unimplemented!')
|
|
||||||
|
|
||||||
def send_async(self, bound_obj: Any, message: Message) -> Message:
|
|
||||||
"""Send a message asynchronously using asyncio.
|
"""Send a message asynchronously using asyncio.
|
||||||
|
|
||||||
The message will be encoded for transport and passed to
|
The message will be encoded for transport and passed to
|
||||||
dispatch_raw_message_async.
|
dispatch_raw_message_async.
|
||||||
"""
|
"""
|
||||||
raise RuntimeError('Unimplemented!')
|
if self._send_async_raw_message_call is None:
|
||||||
|
raise RuntimeError('send_async() is unimplemented for this type.')
|
||||||
|
|
||||||
|
msg_encoded = self._protocol.encode_message(message)
|
||||||
|
response_encoded = await self._send_async_raw_message_call(
|
||||||
|
bound_obj, msg_encoded)
|
||||||
|
response = self._protocol.decode_response(response_encoded)
|
||||||
|
assert isinstance(response, (Response, type(None)))
|
||||||
|
assert (response is None
|
||||||
|
or type(response) in type(message).get_response_types())
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
class MessageReceiver:
|
class MessageReceiver:
|
||||||
@ -552,6 +601,8 @@ class MessageReceiver:
|
|||||||
an Exception being raised on the sending end.
|
an Exception being raised on the sending end.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
is_async = False
|
||||||
|
|
||||||
def __init__(self, protocol: MessageProtocol) -> None:
|
def __init__(self, protocol: MessageProtocol) -> None:
|
||||||
self._protocol = protocol
|
self._protocol = protocol
|
||||||
self._handlers: Dict[Type[Message], Callable] = {}
|
self._handlers: Dict[Type[Message], Callable] = {}
|
||||||
@ -564,6 +615,7 @@ class MessageReceiver:
|
|||||||
The message type handled by the call is determined by its
|
The message type handled by the call is determined by its
|
||||||
type annotation.
|
type annotation.
|
||||||
"""
|
"""
|
||||||
|
# pylint: disable=too-many-locals
|
||||||
# TODO: can use types.GenericAlias in 3.9.
|
# TODO: can use types.GenericAlias in 3.9.
|
||||||
from typing import _GenericAlias # type: ignore
|
from typing import _GenericAlias # type: ignore
|
||||||
from typing import Union, get_type_hints, get_args
|
from typing import Union, get_type_hints, get_args
|
||||||
@ -576,10 +628,19 @@ class MessageReceiver:
|
|||||||
raise ValueError(f'Expected callable signature of {expectedsig};'
|
raise ValueError(f'Expected callable signature of {expectedsig};'
|
||||||
f' got {sig.args}')
|
f' got {sig.args}')
|
||||||
|
|
||||||
|
# Make sure we are only given async methods if we are an async handler
|
||||||
|
# and sync ones otherwise.
|
||||||
|
is_async = inspect.iscoroutinefunction(call)
|
||||||
|
if self.is_async != is_async:
|
||||||
|
msg = ('Expected a sync method; found an async one.' if is_async
|
||||||
|
else 'Expected an async method; found a sync one.')
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
# Check annotation types to determine what message types we handle.
|
# Check annotation types to determine what message types we handle.
|
||||||
# Return-type annotation can be a Union, but we probably don't
|
# Return-type annotation can be a Union, but we probably don't
|
||||||
# have it available at runtime. Explicitly pull it in.
|
# have it available at runtime. Explicitly pull it in.
|
||||||
anns = get_type_hints(call, localns={'Union': Union})
|
anns = get_type_hints(call, localns={'Union': Union})
|
||||||
|
|
||||||
msgtype = anns.get('msg')
|
msgtype = anns.get('msg')
|
||||||
if not isinstance(msgtype, type):
|
if not isinstance(msgtype, type):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -642,54 +703,71 @@ class MessageReceiver:
|
|||||||
logging.warning(msg)
|
logging.warning(msg)
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
|
|
||||||
|
def _decode_incoming_message(self,
|
||||||
|
msg: bytes) -> Tuple[Message, Type[Message]]:
|
||||||
|
# Decode the incoming message.
|
||||||
|
msg_decoded = self._protocol.decode_message(msg)
|
||||||
|
msgtype = type(msg_decoded)
|
||||||
|
assert issubclass(msgtype, Message)
|
||||||
|
return msg_decoded, msgtype
|
||||||
|
|
||||||
|
def _encode_response(self, response: Optional[Response],
|
||||||
|
msgtype: Type[Message]) -> bytes:
|
||||||
|
|
||||||
|
# A return value of None equals EmptyResponse.
|
||||||
|
if response is None:
|
||||||
|
response = EmptyResponse()
|
||||||
|
|
||||||
|
# Re-encode the response.
|
||||||
|
assert isinstance(response, Response)
|
||||||
|
# (user should never explicitly return these)
|
||||||
|
assert not isinstance(response, ErrorResponse)
|
||||||
|
assert type(response) in msgtype.get_response_types()
|
||||||
|
return self._protocol.encode_response(response)
|
||||||
|
|
||||||
|
def _handle_raw_message_error(self, exc: Exception) -> bytes:
|
||||||
|
if self._protocol.log_remote_exceptions:
|
||||||
|
logging.exception('Error handling message.')
|
||||||
|
|
||||||
|
# If anything goes wrong, return a ErrorResponse instead.
|
||||||
|
if (isinstance(exc, CleanError)
|
||||||
|
and self._protocol.preserve_clean_errors):
|
||||||
|
err_response = ErrorResponse(error_message=str(exc),
|
||||||
|
error_type=ErrorType.CLEAN)
|
||||||
|
else:
|
||||||
|
err_response = ErrorResponse(
|
||||||
|
error_message=(traceback.format_exc()
|
||||||
|
if self._protocol.trusted_sender else
|
||||||
|
'An unknown error has occurred.'),
|
||||||
|
error_type=ErrorType.OTHER)
|
||||||
|
return self._protocol.encode_response(err_response)
|
||||||
|
|
||||||
def handle_raw_message(self, bound_obj: Any, msg: bytes) -> bytes:
|
def handle_raw_message(self, bound_obj: Any, msg: bytes) -> bytes:
|
||||||
"""Decode, handle, and return an encoded response for a message."""
|
"""Decode, handle, and return an encoded response for a message."""
|
||||||
try:
|
try:
|
||||||
# Decode the incoming message.
|
msg_decoded, msgtype = self._decode_incoming_message(msg)
|
||||||
msg_decoded = self._protocol.decode_message(msg)
|
|
||||||
msgtype = type(msg_decoded)
|
|
||||||
assert issubclass(msgtype, Message)
|
|
||||||
|
|
||||||
# Call the proper handler.
|
|
||||||
handler = self._handlers.get(msgtype)
|
handler = self._handlers.get(msgtype)
|
||||||
if handler is None:
|
if handler is None:
|
||||||
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
|
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
|
||||||
response = handler(bound_obj, msg_decoded)
|
response = handler(bound_obj, msg_decoded)
|
||||||
|
return self._encode_response(response, msgtype)
|
||||||
# A return value of None equals EmptyResponse.
|
|
||||||
if response is None:
|
|
||||||
response = EmptyResponse()
|
|
||||||
|
|
||||||
# Re-encode the response.
|
|
||||||
assert isinstance(response, Response)
|
|
||||||
# (user should never explicitly return these)
|
|
||||||
assert not isinstance(response, ErrorResponse)
|
|
||||||
assert type(response) in msgtype.get_response_types()
|
|
||||||
return self._protocol.encode_response(response)
|
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
return self._handle_raw_message_error(exc)
|
||||||
|
|
||||||
if self._protocol.log_remote_exceptions:
|
async def handle_raw_message_async(self, bound_obj: Any,
|
||||||
logging.exception('Error handling message.')
|
msg: bytes) -> bytes:
|
||||||
|
|
||||||
# If anything goes wrong, return a ErrorResponse instead.
|
|
||||||
if (isinstance(exc, CleanError)
|
|
||||||
and self._protocol.preserve_clean_errors):
|
|
||||||
err_response = ErrorResponse(error_message=str(exc),
|
|
||||||
error_type=ErrorType.CLEAN)
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
err_response = ErrorResponse(
|
|
||||||
error_message=(traceback.format_exc()
|
|
||||||
if self._protocol.trusted_sender else
|
|
||||||
'An unknown error has occurred.'),
|
|
||||||
error_type=ErrorType.OTHER)
|
|
||||||
return self._protocol.encode_response(err_response)
|
|
||||||
|
|
||||||
async def handle_raw_message_async(self, msg: bytes) -> bytes:
|
|
||||||
"""Should be called when the receiver gets a message.
|
"""Should be called when the receiver gets a message.
|
||||||
|
|
||||||
The return value is the raw response to the message.
|
The return value is the raw response to the message.
|
||||||
"""
|
"""
|
||||||
raise RuntimeError('Unimplemented!')
|
try:
|
||||||
|
msg_decoded, msgtype = self._decode_incoming_message(msg)
|
||||||
|
handler = self._handlers.get(msgtype)
|
||||||
|
if handler is None:
|
||||||
|
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
|
||||||
|
response = await handler(bound_obj, msg_decoded)
|
||||||
|
return self._encode_response(response, msgtype)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
return self._handle_raw_message_error(exc)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user