mirror of
https://github.com/RYDE-WORK/ballistica.git
synced 2026-01-27 09:23:12 +08:00
work on async messaging
This commit is contained in:
parent
b46b1cbca8
commit
c8a6c3733d
6
.idea/dictionaries/ericf.xml
generated
6
.idea/dictionaries/ericf.xml
generated
@ -159,6 +159,7 @@
|
||||
<w>availmins</w>
|
||||
<w>availplug</w>
|
||||
<w>aval</w>
|
||||
<w>awaitable</w>
|
||||
<w>axismotion</w>
|
||||
<w>bacfg</w>
|
||||
<w>backgrounded</w>
|
||||
@ -303,6 +304,7 @@
|
||||
<w>capturetheflag</w>
|
||||
<w>carentity</w>
|
||||
<w>cashregistersound</w>
|
||||
<w>cbgn</w>
|
||||
<w>cbits</w>
|
||||
<w>cbot</w>
|
||||
<w>cbtn</w>
|
||||
@ -313,6 +315,7 @@
|
||||
<w>cdrk</w>
|
||||
<w>cdull</w>
|
||||
<w>cdval</w>
|
||||
<w>cend</w>
|
||||
<w>centeuro</w>
|
||||
<w>centiseconds</w>
|
||||
<w>cfconfig</w>
|
||||
@ -1869,6 +1872,8 @@
|
||||
<w>rawpaths</w>
|
||||
<w>rcade</w>
|
||||
<w>rcfile</w>
|
||||
<w>rcva</w>
|
||||
<w>rcvs</w>
|
||||
<w>rdict</w>
|
||||
<w>rdir</w>
|
||||
<w>readline</w>
|
||||
@ -1997,6 +2002,7 @@
|
||||
<w>selwidget</w>
|
||||
<w>selwidgets</w>
|
||||
<w>sendable</w>
|
||||
<w>sendmethod</w>
|
||||
<w>senze</w>
|
||||
<w>seqtype</w>
|
||||
<w>seqtypestr</w>
|
||||
|
||||
@ -69,6 +69,7 @@
|
||||
<w>availmins</w>
|
||||
<w>avel</w>
|
||||
<w>avels</w>
|
||||
<w>awaitable</w>
|
||||
<w>axismotion</w>
|
||||
<w>backgrounded</w>
|
||||
<w>backgrounding</w>
|
||||
@ -155,10 +156,12 @@
|
||||
<w>cancelbtn</w>
|
||||
<w>capitan</w>
|
||||
<w>cargs</w>
|
||||
<w>cbgn</w>
|
||||
<w>cbtnoffs</w>
|
||||
<w>ccdd</w>
|
||||
<w>ccontext</w>
|
||||
<w>ccylinder</w>
|
||||
<w>cend</w>
|
||||
<w>centiseconds</w>
|
||||
<w>cfgdir</w>
|
||||
<w>cfgpath</w>
|
||||
@ -860,6 +863,8 @@
|
||||
<w>rasterizer</w>
|
||||
<w>rawkey</w>
|
||||
<w>rcade</w>
|
||||
<w>rcva</w>
|
||||
<w>rcvs</w>
|
||||
<w>reaaaly</w>
|
||||
<w>readset</w>
|
||||
<w>realloc</w>
|
||||
@ -922,6 +927,7 @@
|
||||
<w>selwidget</w>
|
||||
<w>selwidgets</w>
|
||||
<w>sendable</w>
|
||||
<w>sendmethod</w>
|
||||
<w>seqlen</w>
|
||||
<w>seqtype</w>
|
||||
<w>seqtypestr</w>
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
<!-- THIS FILE IS AUTO GENERATED; DO NOT EDIT BY HAND -->
|
||||
<h4><em>last updated on 2021-09-19 for Ballistica version 1.6.5 build 20393</em></h4>
|
||||
<h4><em>last updated on 2021-09-21 for Ballistica version 1.6.5 build 20393</em></h4>
|
||||
<p>This page documents the Python classes and functions in the 'ba' module,
|
||||
which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please <a href="mailto:support@froemling.net">let me know</a>. Happy modding!</p>
|
||||
<hr>
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, overload
|
||||
from dataclasses import dataclass
|
||||
|
||||
@ -17,55 +18,55 @@ from efro.message import (Message, Response, MessageProtocol, MessageSender,
|
||||
MessageReceiver)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import List, Type, Any, Callable, Union, Optional
|
||||
from typing import List, Type, Any, Callable, Union, Optional, Awaitable
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class _TMessage1(Message):
|
||||
class _TMsg1(Message):
|
||||
"""Just testing."""
|
||||
ival: int
|
||||
|
||||
@classmethod
|
||||
def get_response_types(cls) -> List[Type[Response]]:
|
||||
return [_TResponse1]
|
||||
return [_TResp1]
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class _TMessage2(Message):
|
||||
class _TMsg2(Message):
|
||||
"""Just testing."""
|
||||
sval: str
|
||||
|
||||
@classmethod
|
||||
def get_response_types(cls) -> List[Type[Response]]:
|
||||
return [_TResponse1, _TResponse2]
|
||||
return [_TResp1, _TResp2]
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class _TMessage3(Message):
|
||||
class _TMsg3(Message):
|
||||
"""Just testing."""
|
||||
sval: str
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class _TResponse1(Response):
|
||||
class _TResp1(Response):
|
||||
"""Just testing."""
|
||||
bval: bool
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class _TResponse2(Response):
|
||||
class _TResp2(Response):
|
||||
"""Just testing."""
|
||||
fval: float
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class _TResponse3(Message):
|
||||
class _TResp3(Message):
|
||||
"""Just testing."""
|
||||
fval: float
|
||||
|
||||
@ -91,55 +92,73 @@ class _BoundTestMessageSender:
|
||||
self._sender = sender
|
||||
|
||||
@overload
|
||||
def send(self, message: _TMessage1) -> _TResponse1:
|
||||
def send(self, message: _TMsg1) -> _TResp1:
|
||||
...
|
||||
|
||||
@overload
|
||||
def send(self, message: _TMessage2) -> Union[_TResponse1, _TResponse2]:
|
||||
def send(self, message: _TMsg2) -> Union[_TResp1, _TResp2]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def send(self, message: _TMessage3) -> None:
|
||||
def send(self, message: _TMsg3) -> None:
|
||||
...
|
||||
|
||||
def send(self, message: Message) -> Optional[Response]:
|
||||
"""Send a message."""
|
||||
"""Send a message synchronously."""
|
||||
return self._sender.send(self._obj, message)
|
||||
|
||||
@overload
|
||||
async def send_async(self, message: _TMsg1) -> _TResp1:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def send_async(self, message: _TMsg2) -> Union[_TResp1, _TResp2]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def send_async(self, message: _TMsg3) -> None:
|
||||
...
|
||||
|
||||
async def send_async(self, message: Message) -> Optional[Response]:
|
||||
"""Send a message asynchronously."""
|
||||
return await self._sender.send_async(self._obj, message)
|
||||
|
||||
|
||||
# SEND_CODE_TEST_END
|
||||
# RCV_CODE_TEST_BEGIN
|
||||
# RCVS_CODE_TEST_BEGIN
|
||||
|
||||
|
||||
class _TestMessageReceiver(MessageReceiver):
|
||||
"""Protocol-specific receiver."""
|
||||
class _TestSyncMessageReceiver(MessageReceiver):
|
||||
"""Protocol-specific synchronous receiver."""
|
||||
|
||||
is_async = False
|
||||
|
||||
def __get__(
|
||||
self,
|
||||
obj: Any,
|
||||
type_in: Any = None,
|
||||
) -> _BoundTestMessageReceiver:
|
||||
return _BoundTestMessageReceiver(obj, self)
|
||||
) -> _BoundTestSyncMessageReceiver:
|
||||
return _BoundTestSyncMessageReceiver(obj, self)
|
||||
|
||||
@overload
|
||||
def handler(
|
||||
self,
|
||||
call: Callable[[Any, _TMessage1], _TResponse1],
|
||||
) -> Callable[[Any, _TMessage1], _TResponse1]:
|
||||
call: Callable[[Any, _TMsg1], _TResp1],
|
||||
) -> Callable[[Any, _TMsg1], _TResp1]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def handler(
|
||||
self,
|
||||
call: Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]],
|
||||
) -> Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]]:
|
||||
call: Callable[[Any, _TMsg2], Union[_TResp1, _TResp2]],
|
||||
) -> Callable[[Any, _TMsg2], Union[_TResp1, _TResp2]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def handler(
|
||||
self,
|
||||
call: Callable[[Any, _TMessage3], None],
|
||||
) -> Callable[[Any, _TMessage3], None]:
|
||||
call: Callable[[Any, _TMsg3], None],
|
||||
) -> Callable[[Any, _TMsg3], None]:
|
||||
...
|
||||
|
||||
def handler(self, call: Callable) -> Callable:
|
||||
@ -148,34 +167,104 @@ class _TestMessageReceiver(MessageReceiver):
|
||||
return call
|
||||
|
||||
|
||||
class _BoundTestMessageReceiver:
|
||||
class _BoundTestSyncMessageReceiver:
|
||||
"""Protocol-specific bound receiver."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
obj: Any,
|
||||
receiver: _TestMessageReceiver,
|
||||
receiver: _TestSyncMessageReceiver,
|
||||
) -> None:
|
||||
assert obj is not None
|
||||
self._obj = obj
|
||||
self._receiver = receiver
|
||||
|
||||
def handle_raw_message(self, message: bytes) -> bytes:
|
||||
"""Handle a raw incoming message."""
|
||||
"""Handle a raw incoming synchronous message."""
|
||||
return self._receiver.handle_raw_message(self._obj, message)
|
||||
|
||||
async def handle_raw_message_async(self, message: bytes) -> bytes:
|
||||
"""Handle a raw incoming asynchronous message."""
|
||||
return await self._receiver.handle_raw_message_async(
|
||||
self._obj, message)
|
||||
|
||||
# RCV_CODE_TEST_END
|
||||
|
||||
# RCVS_CODE_TEST_END
|
||||
# RCVA_CODE_TEST_BEGIN
|
||||
|
||||
|
||||
class _TestAsyncMessageReceiver(MessageReceiver):
|
||||
"""Protocol-specific asynchronous receiver."""
|
||||
|
||||
is_async = True
|
||||
|
||||
def __get__(
|
||||
self,
|
||||
obj: Any,
|
||||
type_in: Any = None,
|
||||
) -> _BoundTestAsyncMessageReceiver:
|
||||
return _BoundTestAsyncMessageReceiver(obj, self)
|
||||
|
||||
@overload
|
||||
def handler(
|
||||
self,
|
||||
call: Callable[[Any, _TMsg1], Awaitable[_TResp1]],
|
||||
) -> Callable[[Any, _TMsg1], Awaitable[_TResp1]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def handler(
|
||||
self,
|
||||
call: Callable[[Any, _TMsg2], Awaitable[Union[_TResp1, _TResp2]]],
|
||||
) -> Callable[[Any, _TMsg2], Awaitable[Union[_TResp1, _TResp2]]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def handler(
|
||||
self,
|
||||
call: Callable[[Any, _TMsg3], Awaitable[None]],
|
||||
) -> Callable[[Any, _TMsg3], Awaitable[None]]:
|
||||
...
|
||||
|
||||
def handler(self, call: Callable) -> Callable:
|
||||
"""Decorator to register message handlers."""
|
||||
self.register_handler(call)
|
||||
return call
|
||||
|
||||
|
||||
class _BoundTestAsyncMessageReceiver:
|
||||
"""Protocol-specific bound receiver."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
obj: Any,
|
||||
receiver: _TestAsyncMessageReceiver,
|
||||
) -> None:
|
||||
assert obj is not None
|
||||
self._obj = obj
|
||||
self._receiver = receiver
|
||||
|
||||
def handle_raw_message(self, message: bytes) -> bytes:
|
||||
"""Handle a raw incoming synchronous message."""
|
||||
return self._receiver.handle_raw_message(self._obj, message)
|
||||
|
||||
async def handle_raw_message_async(self, message: bytes) -> bytes:
|
||||
"""Handle a raw incoming asynchronous message."""
|
||||
return await self._receiver.handle_raw_message_async(
|
||||
self._obj, message)
|
||||
|
||||
|
||||
# RCVA_CODE_TEST_END
|
||||
|
||||
TEST_PROTOCOL = MessageProtocol(
|
||||
message_types={
|
||||
0: _TMessage1,
|
||||
1: _TMessage2,
|
||||
2: _TMessage3,
|
||||
0: _TMsg1,
|
||||
1: _TMsg2,
|
||||
2: _TMsg3,
|
||||
},
|
||||
response_types={
|
||||
0: _TResponse1,
|
||||
1: _TResponse2,
|
||||
0: _TResp1,
|
||||
1: _TResp2,
|
||||
},
|
||||
trusted_sender=True,
|
||||
log_remote_exceptions=False,
|
||||
@ -185,20 +274,21 @@ TEST_PROTOCOL = MessageProtocol(
|
||||
def test_protocol_creation() -> None:
|
||||
"""Test protocol creation."""
|
||||
|
||||
# This should fail because _TMessage1 can return _TResponse1 which
|
||||
# This should fail because _TMsg1 can return _TResp1 which
|
||||
# is not given an id here.
|
||||
with pytest.raises(ValueError):
|
||||
_protocol = MessageProtocol(message_types={0: _TMessage1},
|
||||
response_types={0: _TResponse2})
|
||||
_protocol = MessageProtocol(message_types={0: _TMsg1},
|
||||
response_types={0: _TResp2})
|
||||
|
||||
# Now it should work.
|
||||
_protocol = MessageProtocol(message_types={0: _TMessage1},
|
||||
response_types={0: _TResponse1})
|
||||
_protocol = MessageProtocol(message_types={0: _TMsg1},
|
||||
response_types={0: _TResp1})
|
||||
|
||||
|
||||
def test_sender_module_embedded() -> None:
|
||||
"""Test generation of protocol-specific sender modules for typing/etc."""
|
||||
smod = TEST_PROTOCOL.create_sender_module('Test', private=True)
|
||||
smod = TEST_PROTOCOL.create_sender_module('TestMessageSender',
|
||||
private=True)
|
||||
|
||||
# Clip everything up to our first class declaration.
|
||||
lines = smod.splitlines()
|
||||
@ -218,13 +308,15 @@ def test_sender_module_embedded() -> None:
|
||||
' See test stdout for new code.')
|
||||
|
||||
|
||||
def test_receiver_module_embedded() -> None:
|
||||
def test_receiver_module_sync_embedded() -> None:
|
||||
"""Test generation of protocol-specific sender modules for typing/etc."""
|
||||
smod = TEST_PROTOCOL.create_receiver_module('Test', private=True)
|
||||
smod = TEST_PROTOCOL.create_receiver_module('TestSyncMessageReceiver',
|
||||
is_async=False,
|
||||
private=True)
|
||||
|
||||
# Clip everything up to our first class declaration.
|
||||
lines = smod.splitlines()
|
||||
classline = lines.index('class _TestMessageReceiver(MessageReceiver):')
|
||||
classline = lines.index('class _TestSyncMessageReceiver(MessageReceiver):')
|
||||
clipped = '\n'.join(lines[classline:])
|
||||
|
||||
# This snippet should match what we've got embedded above;
|
||||
@ -232,12 +324,39 @@ def test_receiver_module_embedded() -> None:
|
||||
with open(__file__, encoding='utf-8') as infile:
|
||||
ourcode = infile.read()
|
||||
|
||||
emb = f'# RCV_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCV_CODE_TEST_END\n'
|
||||
emb = f'# RCVS_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCVS_CODE_TEST_END\n'
|
||||
if emb not in ourcode:
|
||||
print(f'EXPECTED EMBEDDED CODE:\n{emb}')
|
||||
raise RuntimeError('Generated sender module does not match embedded;'
|
||||
' test code needs to be updated.'
|
||||
' See test stdout for new code.')
|
||||
print(f'EXPECTED SYNC RECEIVER EMBEDDED CODE:\n{emb}')
|
||||
raise RuntimeError(
|
||||
'Generated sync receiver module does not match embedded;'
|
||||
' test code needs to be updated.'
|
||||
' See test stdout for new code.')
|
||||
|
||||
|
||||
def test_receiver_module_async_embedded() -> None:
|
||||
"""Test generation of protocol-specific sender modules for typing/etc."""
|
||||
smod = TEST_PROTOCOL.create_receiver_module('TestAsyncMessageReceiver',
|
||||
is_async=True,
|
||||
private=True)
|
||||
|
||||
# Clip everything up to our first class declaration.
|
||||
lines = smod.splitlines()
|
||||
classline = lines.index(
|
||||
'class _TestAsyncMessageReceiver(MessageReceiver):')
|
||||
clipped = '\n'.join(lines[classline:])
|
||||
|
||||
# This snippet should match what we've got embedded above;
|
||||
# If not then we need to update our embedded version.
|
||||
with open(__file__, encoding='utf-8') as infile:
|
||||
ourcode = infile.read()
|
||||
|
||||
emb = f'# RCVA_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCVA_CODE_TEST_END\n'
|
||||
if emb not in ourcode:
|
||||
print(f'EXPECTED ASYNC RECEIVER EMBEDDED CODE:\n{emb}')
|
||||
raise RuntimeError(
|
||||
'Generated async receiver module does not match embedded;'
|
||||
' test code needs to be updated.'
|
||||
' See test stdout for new code.')
|
||||
|
||||
|
||||
def test_receiver_creation() -> None:
|
||||
@ -251,13 +370,13 @@ def test_receiver_creation() -> None:
|
||||
class _TestClassR:
|
||||
"""Test class incorporating receive functionality."""
|
||||
|
||||
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
||||
receiver = _TestSyncMessageReceiver(TEST_PROTOCOL)
|
||||
|
||||
@receiver.handler
|
||||
def handle_test_message_2(self, msg: _TMessage2) -> _TResponse2:
|
||||
def handle_test_message_2(self, msg: _TMsg2) -> _TResp2:
|
||||
"""Test."""
|
||||
del msg # Unused
|
||||
return _TResponse2(fval=1.2)
|
||||
return _TResp2(fval=1.2)
|
||||
|
||||
# Validation should fail because not all message types in the
|
||||
# protocol are handled.
|
||||
@ -266,13 +385,13 @@ def test_receiver_creation() -> None:
|
||||
class _TestClassR2:
|
||||
"""Test class incorporating receive functionality."""
|
||||
|
||||
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
||||
receiver = _TestSyncMessageReceiver(TEST_PROTOCOL)
|
||||
|
||||
# Checks that we've added handlers for all message types, etc.
|
||||
receiver.validate()
|
||||
|
||||
|
||||
def test_synchronous_messaging() -> None:
|
||||
def test_full_pipeline() -> None:
|
||||
"""Test the full pipeline."""
|
||||
|
||||
# Define a class that can send messages and one that can receive them.
|
||||
@ -281,66 +400,119 @@ def test_synchronous_messaging() -> None:
|
||||
|
||||
msg = _TestMessageSender(TEST_PROTOCOL)
|
||||
|
||||
def __init__(self, target: TestClassR) -> None:
|
||||
def __init__(self, target: Union[TestClassRSync,
|
||||
TestClassRAsync]) -> None:
|
||||
self._target = target
|
||||
|
||||
@msg.send_raw_handler
|
||||
@msg.send_method
|
||||
def _send_raw_message(self, data: bytes) -> bytes:
|
||||
"""Test."""
|
||||
"""Handle synchronous sending of raw message data."""
|
||||
# Just talk directly to the receiver for this example.
|
||||
# (currently only support synchronous receivers)
|
||||
assert isinstance(self._target, TestClassRSync)
|
||||
return self._target.receiver.handle_raw_message(data)
|
||||
|
||||
class TestClassR:
|
||||
"""Test class incorporating receive functionality."""
|
||||
@msg.send_async_method
|
||||
async def _send_raw_message_async(self, data: bytes) -> bytes:
|
||||
"""Handle asynchronous sending of raw message data."""
|
||||
# Just talk directly to the receiver for this example.
|
||||
# (we can do sync or async receivers)
|
||||
if isinstance(self._target, TestClassRSync):
|
||||
return self._target.receiver.handle_raw_message(data)
|
||||
return await self._target.receiver.handle_raw_message_async(data)
|
||||
|
||||
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
||||
class TestClassRSync:
|
||||
"""Test class incorporating synchronous receive functionality."""
|
||||
|
||||
receiver = _TestSyncMessageReceiver(TEST_PROTOCOL)
|
||||
|
||||
@receiver.handler
|
||||
def handle_test_message_1(self, msg: _TMessage1) -> _TResponse1:
|
||||
def handle_test_message_1(self, msg: _TMsg1) -> _TResp1:
|
||||
"""Test."""
|
||||
if msg.ival == 1:
|
||||
raise CleanError('Testing Clean Error')
|
||||
if msg.ival == 2:
|
||||
raise RuntimeError('Testing Runtime Error')
|
||||
return _TResponse1(bval=True)
|
||||
return _TResp1(bval=True)
|
||||
|
||||
@receiver.handler
|
||||
def handle_test_message_2(
|
||||
self, msg: _TMessage2) -> Union[_TResponse1, _TResponse2]:
|
||||
def handle_test_message_2(self,
|
||||
msg: _TMsg2) -> Union[_TResp1, _TResp2]:
|
||||
"""Test."""
|
||||
del msg # Unused
|
||||
return _TResponse2(fval=1.2)
|
||||
return _TResp2(fval=1.2)
|
||||
|
||||
@receiver.handler
|
||||
def handle_test_message_3(self, msg: _TMessage3) -> None:
|
||||
def handle_test_message_3(self, msg: _TMsg3) -> None:
|
||||
"""Test."""
|
||||
del msg # Unused
|
||||
|
||||
receiver.validate()
|
||||
|
||||
obj_r = TestClassR()
|
||||
obj = TestClassS(target=obj_r)
|
||||
class TestClassRAsync:
|
||||
"""Test class incorporating asynchronous receive functionality."""
|
||||
|
||||
response = obj.msg.send(_TMessage1(ival=0))
|
||||
assert isinstance(response, _TResponse1)
|
||||
receiver = _TestAsyncMessageReceiver(TEST_PROTOCOL)
|
||||
|
||||
response2 = obj.msg.send(_TMessage2(sval='rah'))
|
||||
assert isinstance(response2, (_TResponse1, _TResponse2))
|
||||
@receiver.handler
|
||||
async def handle_test_message_1(self, msg: _TMsg1) -> _TResp1:
|
||||
"""Test."""
|
||||
if msg.ival == 1:
|
||||
raise CleanError('Testing Clean Error')
|
||||
if msg.ival == 2:
|
||||
raise RuntimeError('Testing Runtime Error')
|
||||
return _TResp1(bval=True)
|
||||
|
||||
response3 = obj.msg.send(_TMessage3(sval='rah'))
|
||||
assert response3 is None
|
||||
@receiver.handler
|
||||
async def handle_test_message_2(
|
||||
self, msg: _TMsg2) -> Union[_TResp1, _TResp2]:
|
||||
"""Test."""
|
||||
del msg # Unused
|
||||
return _TResp2(fval=1.2)
|
||||
|
||||
# Make sure static typing lines up too.
|
||||
@receiver.handler
|
||||
async def handle_test_message_3(self, msg: _TMsg3) -> None:
|
||||
"""Test."""
|
||||
del msg # Unused
|
||||
|
||||
receiver.validate()
|
||||
|
||||
obj_r_sync = TestClassRSync()
|
||||
obj_r_async = TestClassRAsync()
|
||||
obj = TestClassS(target=obj_r_sync)
|
||||
obj2 = TestClassS(target=obj_r_async)
|
||||
|
||||
# Test sends (of sync and async varieties).
|
||||
response1 = obj.msg.send(_TMsg1(ival=0))
|
||||
response2 = obj.msg.send(_TMsg2(sval='rah'))
|
||||
response3 = obj.msg.send(_TMsg3(sval='rah'))
|
||||
response4 = asyncio.run(obj.msg.send_async(_TMsg1(ival=0)))
|
||||
|
||||
# Make sure static typing lines up with what we expect.
|
||||
if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
|
||||
assert static_type_equals(response, _TResponse1)
|
||||
assert static_type_equals(response1, _TResp1)
|
||||
assert static_type_equals(response3, None)
|
||||
|
||||
assert isinstance(response1, _TResp1)
|
||||
assert isinstance(response2, (_TResp1, _TResp2))
|
||||
assert response3 is None
|
||||
assert isinstance(response4, _TResp1)
|
||||
|
||||
# Remote CleanErrors should come across locally as the same.
|
||||
try:
|
||||
_response4 = obj.msg.send(_TMessage1(ival=1))
|
||||
_response5 = obj.msg.send(_TMsg1(ival=1))
|
||||
except Exception as exc:
|
||||
assert isinstance(exc, CleanError)
|
||||
assert str(exc) == 'Testing Clean Error'
|
||||
|
||||
# Other remote errors should result in RemoteError.
|
||||
with pytest.raises(RemoteError):
|
||||
_response4 = obj.msg.send(_TMessage1(ival=2))
|
||||
_response5 = obj.msg.send(_TMsg1(ival=2))
|
||||
|
||||
# Now test sends to async handlers.
|
||||
response6 = asyncio.run(obj2.msg.send_async(_TMsg1(ival=0)))
|
||||
assert isinstance(response6, _TResp1)
|
||||
|
||||
# Make sure static typing lines up with what we expect.
|
||||
if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
|
||||
assert static_type_equals(response6, _TResp1)
|
||||
|
||||
@ -22,7 +22,7 @@ from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs,
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set,
|
||||
Sequence, Union)
|
||||
Sequence, Union, Awaitable)
|
||||
from efro.error import CommunicationError
|
||||
|
||||
TM = TypeVar('TM', bound='MessageSender')
|
||||
@ -300,7 +300,7 @@ class MessageProtocol:
|
||||
return out
|
||||
|
||||
def create_sender_module(self,
|
||||
classname: str,
|
||||
basename: str,
|
||||
private: bool = False) -> str:
|
||||
""""Create a Python module defining a MessageSender subclass.
|
||||
|
||||
@ -308,29 +308,35 @@ class MessageProtocol:
|
||||
for the varieties of send calls for message/response types defined
|
||||
in the protocol.
|
||||
|
||||
Class names are based on basename; a basename 'FooSender' will
|
||||
result in classes FooSender and BoundFooSender.
|
||||
|
||||
If 'private' is True, class-names will be prefixed with an '_'.
|
||||
|
||||
Note that line lengths are not clipped, so output may need to be
|
||||
run through a formatter to prevent lint warnings about excessive
|
||||
line lengths.
|
||||
"""
|
||||
# pylint: disable=too-many-locals
|
||||
|
||||
ppre = '_' if private else ''
|
||||
out = self._get_module_header('sender')
|
||||
out += (f'class {ppre}{classname}MessageSender(MessageSender):\n'
|
||||
out += (f'class {ppre}{basename}(MessageSender):\n'
|
||||
f' """Protocol-specific sender."""\n'
|
||||
f'\n'
|
||||
f' def __get__(self,\n'
|
||||
f' obj: Any,\n'
|
||||
f' type_in: Any = None)'
|
||||
f' -> {ppre}Bound{classname}MessageSender:\n'
|
||||
f' return {ppre}Bound{classname}MessageSender'
|
||||
f' -> {ppre}Bound{basename}:\n'
|
||||
f' return {ppre}Bound{basename}'
|
||||
f'(obj, self)\n'
|
||||
f'\n'
|
||||
f'\n'
|
||||
f'class {ppre}Bound{classname}MessageSender:\n'
|
||||
f'class {ppre}Bound{basename}:\n'
|
||||
f' """Protocol-specific bound sender."""\n'
|
||||
f'\n'
|
||||
f' def __init__(self, obj: Any,'
|
||||
f' sender: {ppre}{classname}MessageSender) -> None:\n'
|
||||
f' sender: {ppre}{basename}) -> None:\n'
|
||||
f' assert obj is not None\n'
|
||||
f' self._obj = obj\n'
|
||||
f' self._sender = sender\n')
|
||||
@ -352,29 +358,37 @@ class MessageProtocol:
|
||||
return 'None' if rtype is EmptyResponse else rtype.__name__
|
||||
|
||||
if len(msgtypes) > 1:
|
||||
for msgtype in msgtypes:
|
||||
msgtypevar = msgtype.__name__
|
||||
rtypes = msgtype.get_response_types()
|
||||
if len(rtypes) > 1:
|
||||
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
|
||||
rtypevar = f'Union[{tps}]'
|
||||
else:
|
||||
rtypevar = _filt_tp_name(rtypes[0])
|
||||
for async_pass in False, True:
|
||||
pfx = 'async ' if async_pass else ''
|
||||
sfx = '_async' if async_pass else ''
|
||||
awt = 'await ' if async_pass else ''
|
||||
how = 'asynchronously' if async_pass else 'synchronously'
|
||||
for msgtype in msgtypes:
|
||||
msgtypevar = msgtype.__name__
|
||||
rtypes = msgtype.get_response_types()
|
||||
if len(rtypes) > 1:
|
||||
tps = ', '.join(_filt_tp_name(t) for t in rtypes)
|
||||
rtypevar = f'Union[{tps}]'
|
||||
else:
|
||||
rtypevar = _filt_tp_name(rtypes[0])
|
||||
out += (f'\n'
|
||||
f' @overload\n'
|
||||
f' {pfx}def send{sfx}(self,'
|
||||
f' message: {msgtypevar})'
|
||||
f' -> {rtypevar}:\n'
|
||||
f' ...\n')
|
||||
out += (f'\n'
|
||||
f' @overload\n'
|
||||
f' def send(self, message: {msgtypevar})'
|
||||
f' -> {rtypevar}:\n'
|
||||
f' ...\n')
|
||||
out += ('\n'
|
||||
' def send(self, message: Message)'
|
||||
' -> Optional[Response]:\n'
|
||||
' """Send a message."""\n'
|
||||
' return self._sender.send(self._obj, message)\n')
|
||||
f' {pfx}def send{sfx}(self, message: Message)'
|
||||
f' -> Optional[Response]:\n'
|
||||
f' """Send a message {how}."""\n'
|
||||
f' return {awt}self._sender.'
|
||||
f'send{sfx}(self._obj, message)\n')
|
||||
|
||||
return out
|
||||
|
||||
def create_receiver_module(self,
|
||||
classname: str,
|
||||
basename: str,
|
||||
is_async: bool,
|
||||
private: bool = False) -> str:
|
||||
""""Create a Python module defining a MessageReceiver subclass.
|
||||
|
||||
@ -382,21 +396,33 @@ class MessageProtocol:
|
||||
for the register method for message/response types defined in
|
||||
the protocol.
|
||||
|
||||
Class names are based on basename; a basename 'FooReceiver' will
|
||||
result in FooReceiver and BoundFooReceiver.
|
||||
|
||||
If 'is_async' is True, handle_raw_message() will be an async method
|
||||
and the @handler decorator will expect async methods.
|
||||
|
||||
If 'private' is True, class-names will be prefixed with an '_'.
|
||||
|
||||
Note that line lengths are not clipped, so output may need to be
|
||||
run through a formatter to prevent lint warnings about excessive
|
||||
line lengths.
|
||||
"""
|
||||
# pylint: disable=too-many-locals
|
||||
desc = 'asynchronous' if is_async else 'synchronous'
|
||||
ppre = '_' if private else ''
|
||||
out = self._get_module_header('receiver')
|
||||
out += (f'class {ppre}{classname}MessageReceiver(MessageReceiver):\n'
|
||||
f' """Protocol-specific receiver."""\n'
|
||||
out += (f'class {ppre}{basename}(MessageReceiver):\n'
|
||||
f' """Protocol-specific {desc} receiver."""\n'
|
||||
f'\n'
|
||||
f' is_async = {is_async}\n'
|
||||
f'\n'
|
||||
f' def __get__(\n'
|
||||
f' self,\n'
|
||||
f' obj: Any,\n'
|
||||
f' type_in: Any = None,\n'
|
||||
f' ) -> {ppre}Bound{classname}MessageReceiver:\n'
|
||||
f' return {ppre}Bound{classname}MessageReceiver('
|
||||
f' ) -> {ppre}Bound{basename}:\n'
|
||||
f' return {ppre}Bound{basename}('
|
||||
f'obj, self)\n')
|
||||
|
||||
# Define handler() overloads for all registered message types.
|
||||
@ -416,6 +442,8 @@ class MessageProtocol:
|
||||
return 'None' if rtype is EmptyResponse else rtype.__name__
|
||||
|
||||
if len(msgtypes) > 1:
|
||||
cbgn = 'Awaitable[' if is_async else ''
|
||||
cend = ']' if is_async else ''
|
||||
for msgtype in msgtypes:
|
||||
msgtypevar = msgtype.__name__
|
||||
rtypes = msgtype.get_response_types()
|
||||
@ -424,6 +452,7 @@ class MessageProtocol:
|
||||
rtypevar = f'Union[{tps}]'
|
||||
else:
|
||||
rtypevar = _filt_tp_name(rtypes[0])
|
||||
rtypevar = f'{cbgn}{rtypevar}{cend}'
|
||||
out += (
|
||||
f'\n'
|
||||
f' @overload\n'
|
||||
@ -441,22 +470,29 @@ class MessageProtocol:
|
||||
|
||||
out += (f'\n'
|
||||
f'\n'
|
||||
f'class {ppre}Bound{classname}MessageReceiver:\n'
|
||||
f'class {ppre}Bound{basename}:\n'
|
||||
f' """Protocol-specific bound receiver."""\n'
|
||||
f'\n'
|
||||
f' def __init__(\n'
|
||||
f' self,\n'
|
||||
f' obj: Any,\n'
|
||||
f' receiver: {ppre}{classname}MessageReceiver,\n'
|
||||
f' receiver: {ppre}{basename},\n'
|
||||
f' ) -> None:\n'
|
||||
f' assert obj is not None\n'
|
||||
f' self._obj = obj\n'
|
||||
f' self._receiver = receiver\n'
|
||||
f'\n'
|
||||
f' def handle_raw_message(self, message: bytes) -> bytes:\n'
|
||||
f' """Handle a raw incoming message."""\n'
|
||||
f' """Handle a raw incoming synchronous message."""\n'
|
||||
f' return self._receiver.handle_raw_message'
|
||||
f'(self._obj, message)\n')
|
||||
f'(self._obj, message)\n'
|
||||
f'\n'
|
||||
f' async def handle_raw_message_async(self, message: bytes)'
|
||||
f' -> bytes:\n'
|
||||
f' """Handle a raw incoming asynchronous message."""\n'
|
||||
f' return await'
|
||||
f' self._receiver.handle_raw_message_async(\n'
|
||||
f' self._obj, message)\n')
|
||||
|
||||
return out
|
||||
|
||||
@ -471,7 +507,7 @@ class MessageSender:
|
||||
class MyClass:
|
||||
msg = MyMessageSender(some_protocol)
|
||||
|
||||
@msg.send_raw_handler
|
||||
@msg.sendmethod
|
||||
def send_raw_message(self, message: bytes) -> bytes:
|
||||
# Actually send the message here.
|
||||
|
||||
@ -485,8 +521,10 @@ class MessageSender:
|
||||
self._protocol = protocol
|
||||
self._send_raw_message_call: Optional[Callable[[Any, bytes],
|
||||
bytes]] = None
|
||||
self._send_async_raw_message_call: Optional[Callable[
|
||||
[Any, bytes], Awaitable[bytes]]] = None
|
||||
|
||||
def send_raw_handler(
|
||||
def send_method(
|
||||
self, call: Callable[[Any, bytes],
|
||||
bytes]) -> Callable[[Any, bytes], bytes]:
|
||||
"""Function decorator for setting raw send method."""
|
||||
@ -494,6 +532,14 @@ class MessageSender:
|
||||
self._send_raw_message_call = call
|
||||
return call
|
||||
|
||||
def send_async_method(
|
||||
self, call: Callable[[Any, bytes], Awaitable[bytes]]
|
||||
) -> Callable[[Any, bytes], Awaitable[bytes]]:
|
||||
"""Function decorator for setting raw send-async method."""
|
||||
assert self._send_async_raw_message_call is None
|
||||
self._send_async_raw_message_call = call
|
||||
return call
|
||||
|
||||
def send(self, bound_obj: Any, message: Message) -> Optional[Response]:
|
||||
"""Send a message and receive a response.
|
||||
|
||||
@ -510,21 +556,24 @@ class MessageSender:
|
||||
or type(response) in type(message).get_response_types())
|
||||
return response
|
||||
|
||||
def send_bg(self, bound_obj: Any, message: Message) -> Message:
|
||||
"""Send a message asynchronously and receive a future.
|
||||
|
||||
The message will be encoded for transport and passed to
|
||||
dispatch_raw_message from a background thread.
|
||||
"""
|
||||
raise RuntimeError('Unimplemented!')
|
||||
|
||||
def send_async(self, bound_obj: Any, message: Message) -> Message:
|
||||
async def send_async(self, bound_obj: Any,
|
||||
message: Message) -> Optional[Response]:
|
||||
"""Send a message asynchronously using asyncio.
|
||||
|
||||
The message will be encoded for transport and passed to
|
||||
dispatch_raw_message_async.
|
||||
"""
|
||||
raise RuntimeError('Unimplemented!')
|
||||
if self._send_async_raw_message_call is None:
|
||||
raise RuntimeError('send_async() is unimplemented for this type.')
|
||||
|
||||
msg_encoded = self._protocol.encode_message(message)
|
||||
response_encoded = await self._send_async_raw_message_call(
|
||||
bound_obj, msg_encoded)
|
||||
response = self._protocol.decode_response(response_encoded)
|
||||
assert isinstance(response, (Response, type(None)))
|
||||
assert (response is None
|
||||
or type(response) in type(message).get_response_types())
|
||||
return response
|
||||
|
||||
|
||||
class MessageReceiver:
|
||||
@ -552,6 +601,8 @@ class MessageReceiver:
|
||||
an Exception being raised on the sending end.
|
||||
"""
|
||||
|
||||
is_async = False
|
||||
|
||||
def __init__(self, protocol: MessageProtocol) -> None:
|
||||
self._protocol = protocol
|
||||
self._handlers: Dict[Type[Message], Callable] = {}
|
||||
@ -564,6 +615,7 @@ class MessageReceiver:
|
||||
The message type handled by the call is determined by its
|
||||
type annotation.
|
||||
"""
|
||||
# pylint: disable=too-many-locals
|
||||
# TODO: can use types.GenericAlias in 3.9.
|
||||
from typing import _GenericAlias # type: ignore
|
||||
from typing import Union, get_type_hints, get_args
|
||||
@ -576,10 +628,19 @@ class MessageReceiver:
|
||||
raise ValueError(f'Expected callable signature of {expectedsig};'
|
||||
f' got {sig.args}')
|
||||
|
||||
# Make sure we are only given async methods if we are an async handler
|
||||
# and sync ones otherwise.
|
||||
is_async = inspect.iscoroutinefunction(call)
|
||||
if self.is_async != is_async:
|
||||
msg = ('Expected a sync method; found an async one.' if is_async
|
||||
else 'Expected an async method; found a sync one.')
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check annotation types to determine what message types we handle.
|
||||
# Return-type annotation can be a Union, but we probably don't
|
||||
# have it available at runtime. Explicitly pull it in.
|
||||
anns = get_type_hints(call, localns={'Union': Union})
|
||||
|
||||
msgtype = anns.get('msg')
|
||||
if not isinstance(msgtype, type):
|
||||
raise TypeError(
|
||||
@ -642,54 +703,71 @@ class MessageReceiver:
|
||||
logging.warning(msg)
|
||||
raise TypeError(msg)
|
||||
|
||||
def _decode_incoming_message(self,
|
||||
msg: bytes) -> Tuple[Message, Type[Message]]:
|
||||
# Decode the incoming message.
|
||||
msg_decoded = self._protocol.decode_message(msg)
|
||||
msgtype = type(msg_decoded)
|
||||
assert issubclass(msgtype, Message)
|
||||
return msg_decoded, msgtype
|
||||
|
||||
def _encode_response(self, response: Optional[Response],
|
||||
msgtype: Type[Message]) -> bytes:
|
||||
|
||||
# A return value of None equals EmptyResponse.
|
||||
if response is None:
|
||||
response = EmptyResponse()
|
||||
|
||||
# Re-encode the response.
|
||||
assert isinstance(response, Response)
|
||||
# (user should never explicitly return these)
|
||||
assert not isinstance(response, ErrorResponse)
|
||||
assert type(response) in msgtype.get_response_types()
|
||||
return self._protocol.encode_response(response)
|
||||
|
||||
def _handle_raw_message_error(self, exc: Exception) -> bytes:
|
||||
if self._protocol.log_remote_exceptions:
|
||||
logging.exception('Error handling message.')
|
||||
|
||||
# If anything goes wrong, return a ErrorResponse instead.
|
||||
if (isinstance(exc, CleanError)
|
||||
and self._protocol.preserve_clean_errors):
|
||||
err_response = ErrorResponse(error_message=str(exc),
|
||||
error_type=ErrorType.CLEAN)
|
||||
else:
|
||||
err_response = ErrorResponse(
|
||||
error_message=(traceback.format_exc()
|
||||
if self._protocol.trusted_sender else
|
||||
'An unknown error has occurred.'),
|
||||
error_type=ErrorType.OTHER)
|
||||
return self._protocol.encode_response(err_response)
|
||||
|
||||
def handle_raw_message(self, bound_obj: Any, msg: bytes) -> bytes:
|
||||
"""Decode, handle, and return an encoded response for a message."""
|
||||
try:
|
||||
# Decode the incoming message.
|
||||
msg_decoded = self._protocol.decode_message(msg)
|
||||
msgtype = type(msg_decoded)
|
||||
assert issubclass(msgtype, Message)
|
||||
|
||||
# Call the proper handler.
|
||||
msg_decoded, msgtype = self._decode_incoming_message(msg)
|
||||
handler = self._handlers.get(msgtype)
|
||||
if handler is None:
|
||||
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
|
||||
response = handler(bound_obj, msg_decoded)
|
||||
|
||||
# A return value of None equals EmptyResponse.
|
||||
if response is None:
|
||||
response = EmptyResponse()
|
||||
|
||||
# Re-encode the response.
|
||||
assert isinstance(response, Response)
|
||||
# (user should never explicitly return these)
|
||||
assert not isinstance(response, ErrorResponse)
|
||||
assert type(response) in msgtype.get_response_types()
|
||||
return self._protocol.encode_response(response)
|
||||
return self._encode_response(response, msgtype)
|
||||
|
||||
except Exception as exc:
|
||||
return self._handle_raw_message_error(exc)
|
||||
|
||||
if self._protocol.log_remote_exceptions:
|
||||
logging.exception('Error handling message.')
|
||||
|
||||
# If anything goes wrong, return a ErrorResponse instead.
|
||||
if (isinstance(exc, CleanError)
|
||||
and self._protocol.preserve_clean_errors):
|
||||
err_response = ErrorResponse(error_message=str(exc),
|
||||
error_type=ErrorType.CLEAN)
|
||||
|
||||
else:
|
||||
|
||||
err_response = ErrorResponse(
|
||||
error_message=(traceback.format_exc()
|
||||
if self._protocol.trusted_sender else
|
||||
'An unknown error has occurred.'),
|
||||
error_type=ErrorType.OTHER)
|
||||
return self._protocol.encode_response(err_response)
|
||||
|
||||
async def handle_raw_message_async(self, msg: bytes) -> bytes:
|
||||
async def handle_raw_message_async(self, bound_obj: Any,
|
||||
msg: bytes) -> bytes:
|
||||
"""Should be called when the receiver gets a message.
|
||||
|
||||
The return value is the raw response to the message.
|
||||
"""
|
||||
raise RuntimeError('Unimplemented!')
|
||||
try:
|
||||
msg_decoded, msgtype = self._decode_incoming_message(msg)
|
||||
handler = self._handlers.get(msgtype)
|
||||
if handler is None:
|
||||
raise RuntimeError(f'Got unhandled message type: {msgtype}.')
|
||||
response = await handler(bound_obj, msg_decoded)
|
||||
return self._encode_response(response, msgtype)
|
||||
|
||||
except Exception as exc:
|
||||
return self._handle_raw_message_error(exc)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user