mirror of
https://github.com/RYDE-WORK/ballistica.git
synced 2026-01-26 17:03:14 +08:00
module generation for new messaging stuff
This commit is contained in:
parent
14c1a20ad0
commit
95bbb89d14
12
.idea/dictionaries/ericf.xml
generated
12
.idea/dictionaries/ericf.xml
generated
@ -361,6 +361,7 @@
|
||||
<w>chromebooks</w>
|
||||
<w>chunksize</w>
|
||||
<w>cjkcodecs</w>
|
||||
<w>classline</w>
|
||||
<w>classmethod</w>
|
||||
<w>classmethods</w>
|
||||
<w>classname</w>
|
||||
@ -1060,6 +1061,7 @@
|
||||
<w>imgh</w>
|
||||
<w>imghdr</w>
|
||||
<w>imgw</w>
|
||||
<w>importlines</w>
|
||||
<w>incentivized</w>
|
||||
<w>includetest</w>
|
||||
<w>incmd</w>
|
||||
@ -1135,6 +1137,7 @@
|
||||
<w>jisx</w>
|
||||
<w>jite</w>
|
||||
<w>jittering</w>
|
||||
<w>jnames</w>
|
||||
<w>joedeshon</w>
|
||||
<w>johab</w>
|
||||
<w>joinable</w>
|
||||
@ -1406,6 +1409,8 @@
|
||||
<w>msgdict</w>
|
||||
<w>msgfull</w>
|
||||
<w>msgtype</w>
|
||||
<w>msgtypes</w>
|
||||
<w>msgtypevar</w>
|
||||
<w>mshell</w>
|
||||
<w>msvccompiler</w>
|
||||
<w>msvcp</w>
|
||||
@ -1414,6 +1419,7 @@
|
||||
<w>mtrans</w>
|
||||
<w>mtvos</w>
|
||||
<w>mtype</w>
|
||||
<w>mtypenames</w>
|
||||
<w>mult</w>
|
||||
<w>multibytecodec</w>
|
||||
<w>multikillcount</w>
|
||||
@ -1570,6 +1576,7 @@
|
||||
<w>osval</w>
|
||||
<w>otherplayer</w>
|
||||
<w>otherspawn</w>
|
||||
<w>ourcode</w>
|
||||
<w>ourhash</w>
|
||||
<w>ourname</w>
|
||||
<w>ourself</w>
|
||||
@ -1699,6 +1706,7 @@
|
||||
<w>poweruptype</w>
|
||||
<w>powervr</w>
|
||||
<w>ppos</w>
|
||||
<w>ppre</w>
|
||||
<w>pproxy</w>
|
||||
<w>pptabcom</w>
|
||||
<w>pragmas</w>
|
||||
@ -1886,6 +1894,7 @@
|
||||
<w>respawnicon</w>
|
||||
<w>responsetype</w>
|
||||
<w>responsetypes</w>
|
||||
<w>responsetypevar</w>
|
||||
<w>resultstr</w>
|
||||
<w>retrysecs</w>
|
||||
<w>returncode</w>
|
||||
@ -1921,6 +1930,7 @@
|
||||
<w>rtnetlink</w>
|
||||
<w>rtxt</w>
|
||||
<w>rtypes</w>
|
||||
<w>rtypevar</w>
|
||||
<w>runmypy</w>
|
||||
<w>runonly</w>
|
||||
<w>runpy</w>
|
||||
@ -2053,6 +2063,7 @@
|
||||
<w>smag</w>
|
||||
<w>smallscale</w>
|
||||
<w>smlh</w>
|
||||
<w>smod</w>
|
||||
<w>smoothstep</w>
|
||||
<w>smoothy</w>
|
||||
<w>smtpd</w>
|
||||
@ -2333,6 +2344,7 @@
|
||||
<w>touchpad</w>
|
||||
<w>tournamententry</w>
|
||||
<w>tournamentscores</w>
|
||||
<w>tpimports</w>
|
||||
<w>tplayer</w>
|
||||
<w>tpos</w>
|
||||
<w>tproxy</w>
|
||||
|
||||
12
ballisticacore-cmake/.idea/dictionaries/ericf.xml
generated
12
ballisticacore-cmake/.idea/dictionaries/ericf.xml
generated
@ -179,6 +179,7 @@
|
||||
<w>chunksize</w>
|
||||
<w>cjief</w>
|
||||
<w>classdict</w>
|
||||
<w>classline</w>
|
||||
<w>cleanupcheck</w>
|
||||
<w>clientid</w>
|
||||
<w>clientinfo</w>
|
||||
@ -493,6 +494,7 @@
|
||||
<w>illum</w>
|
||||
<w>ilock</w>
|
||||
<w>imagewidget</w>
|
||||
<w>importlines</w>
|
||||
<w>incentivized</w>
|
||||
<w>inet</w>
|
||||
<w>infotxt</w>
|
||||
@ -533,6 +535,7 @@
|
||||
<w>jaxis</w>
|
||||
<w>jcjwf</w>
|
||||
<w>jmessage</w>
|
||||
<w>jnames</w>
|
||||
<w>keepalives</w>
|
||||
<w>keyanntype</w>
|
||||
<w>keycode</w>
|
||||
@ -643,6 +646,9 @@
|
||||
<w>msgdict</w>
|
||||
<w>msgfull</w>
|
||||
<w>msgtype</w>
|
||||
<w>msgtypes</w>
|
||||
<w>msgtypevar</w>
|
||||
<w>mtypenames</w>
|
||||
<w>mult</w>
|
||||
<w>multing</w>
|
||||
<w>multipass</w>
|
||||
@ -746,6 +752,7 @@
|
||||
<w>osis</w>
|
||||
<w>osssssssssss</w>
|
||||
<w>ostype</w>
|
||||
<w>ourcode</w>
|
||||
<w>ourname</w>
|
||||
<w>ourself</w>
|
||||
<w>ourstanding</w>
|
||||
@ -781,6 +788,7 @@
|
||||
<w>postinit</w>
|
||||
<w>postrun</w>
|
||||
<w>powerup</w>
|
||||
<w>ppre</w>
|
||||
<w>pptabcom</w>
|
||||
<w>precalc</w>
|
||||
<w>predeclare</w>
|
||||
@ -872,6 +880,7 @@
|
||||
<w>resetbtn</w>
|
||||
<w>resetinput</w>
|
||||
<w>responsetypes</w>
|
||||
<w>responsetypevar</w>
|
||||
<w>resync</w>
|
||||
<w>retrysecs</w>
|
||||
<w>retval</w>
|
||||
@ -889,6 +898,7 @@
|
||||
<w>rscode</w>
|
||||
<w>rsgc</w>
|
||||
<w>rtypes</w>
|
||||
<w>rtypevar</w>
|
||||
<w>runnables</w>
|
||||
<w>rvec</w>
|
||||
<w>rvel</w>
|
||||
@ -945,6 +955,7 @@
|
||||
<w>simpletype</w>
|
||||
<w>sisssssssss</w>
|
||||
<w>sixteenbits</w>
|
||||
<w>smod</w>
|
||||
<w>smoothering</w>
|
||||
<w>smoothstep</w>
|
||||
<w>smoothy</w>
|
||||
@ -1053,6 +1064,7 @@
|
||||
<w>touchpad</w>
|
||||
<w>toucs</w>
|
||||
<w>toutf</w>
|
||||
<w>tpimports</w>
|
||||
<w>tracebacks</w>
|
||||
<w>tracestr</w>
|
||||
<w>trackpad</w>
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
<!-- THIS FILE IS AUTO GENERATED; DO NOT EDIT BY HAND -->
|
||||
<h4><em>last updated on 2021-09-07 for Ballistica version 1.6.5 build 20391</em></h4>
|
||||
<h4><em>last updated on 2021-09-08 for Ballistica version 1.6.5 build 20391</em></h4>
|
||||
<p>This page documents the Python classes and functions in the 'ba' module,
|
||||
which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please <a href="mailto:support@froemling.net">let me know</a>. Happy modding!</p>
|
||||
<hr>
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, overload
|
||||
from dataclasses import dataclass
|
||||
|
||||
@ -11,7 +12,7 @@ import pytest
|
||||
|
||||
from efro.error import CleanError, RemoteError
|
||||
from efro.dataclassio import ioprepped
|
||||
from efro.message import (Message, MessageProtocol, MessageSender,
|
||||
from efro.message import (Message, Response, MessageProtocol, MessageSender,
|
||||
MessageReceiver)
|
||||
from efrotools.statictest import static_type_equals
|
||||
|
||||
@ -21,52 +22,52 @@ if TYPE_CHECKING:
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class _TestMessage1(Message):
|
||||
class _TMessage1(Message):
|
||||
"""Just testing."""
|
||||
ival: int
|
||||
|
||||
@classmethod
|
||||
def get_response_types(cls) -> List[Type[Message]]:
|
||||
return [_TestMessageR1]
|
||||
def get_response_types(cls) -> List[Type[Response]]:
|
||||
return [_TResponse1]
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class _TestMessage2(Message):
|
||||
class _TMessage2(Message):
|
||||
"""Just testing."""
|
||||
sval: str
|
||||
|
||||
@classmethod
|
||||
def get_response_types(cls) -> List[Type[Message]]:
|
||||
return [_TestMessageR1, _TestMessageR2]
|
||||
def get_response_types(cls) -> List[Type[Response]]:
|
||||
return [_TResponse1, _TResponse2]
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class _TestMessageR1(Message):
|
||||
class _TResponse1(Response):
|
||||
"""Just testing."""
|
||||
bval: bool
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class _TestMessageR2(Message):
|
||||
class _TResponse2(Response):
|
||||
"""Just testing."""
|
||||
fval: float
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class _TestMessageR3(Message):
|
||||
class _TResponse3(Message):
|
||||
"""Just testing."""
|
||||
fval: float
|
||||
|
||||
|
||||
class _TestMessageSender(MessageSender):
|
||||
"""Testing type overrides for message sending.
|
||||
# SEND_CODE_TEST_BEGIN
|
||||
|
||||
Normally this would be auto-generated based on the protocol.
|
||||
"""
|
||||
|
||||
class _TestMessageSender(MessageSender):
|
||||
"""Protocol-specific sender."""
|
||||
|
||||
def __get__(self,
|
||||
obj: Any,
|
||||
@ -75,10 +76,7 @@ class _TestMessageSender(MessageSender):
|
||||
|
||||
|
||||
class _BoundTestMessageSender:
|
||||
"""Testing type overrides for message sending.
|
||||
|
||||
Normally this would be auto-generated based on the protocol.
|
||||
"""
|
||||
"""Protocol-specific bound sender."""
|
||||
|
||||
def __init__(self, obj: Any, sender: _TestMessageSender) -> None:
|
||||
assert obj is not None
|
||||
@ -86,56 +84,60 @@ class _BoundTestMessageSender:
|
||||
self._sender = sender
|
||||
|
||||
@overload
|
||||
def send(self, message: _TestMessage1) -> _TestMessageR1:
|
||||
def send(self, message: _TMessage1) -> _TResponse1:
|
||||
...
|
||||
|
||||
@overload
|
||||
def send(self,
|
||||
message: _TestMessage2) -> Union[_TestMessageR1, _TestMessageR2]:
|
||||
def send(self, message: _TMessage2) -> Union[_TResponse1, _TResponse2]:
|
||||
...
|
||||
|
||||
def send(self, message: Message) -> Message:
|
||||
"""Send a particular message type."""
|
||||
def send(self, message: Message) -> Response:
|
||||
"""Send a message."""
|
||||
return self._sender.send(self._obj, message)
|
||||
|
||||
|
||||
# SEND_CODE_TEST_END
|
||||
# RCV_CODE_TEST_BEGIN
|
||||
|
||||
|
||||
class _TestMessageReceiver(MessageReceiver):
|
||||
"""Testing type overrides for message receiving.
|
||||
"""Protocol-specific receiver."""
|
||||
|
||||
Normally this would be auto-generated based on the protocol.
|
||||
"""
|
||||
|
||||
def __get__(self,
|
||||
obj: Any,
|
||||
type_in: Any = None) -> _BoundTestMessageReceiver:
|
||||
def __get__(
|
||||
self,
|
||||
obj: Any,
|
||||
type_in: Any = None,
|
||||
) -> _BoundTestMessageReceiver:
|
||||
return _BoundTestMessageReceiver(obj, self)
|
||||
|
||||
@overload
|
||||
def handler(
|
||||
self, call: Callable[[Any, _TestMessage1], _TestMessageR1]
|
||||
) -> Callable[[Any, _TestMessage1], _TestMessageR1]:
|
||||
self,
|
||||
call: Callable[[Any, _TMessage1], _TResponse1],
|
||||
) -> Callable[[Any, _TMessage1], _TResponse1]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def handler(
|
||||
self, call: Callable[[Any, _TestMessage2], Union[_TestMessageR1,
|
||||
_TestMessageR2]]
|
||||
) -> Callable[[Any, _TestMessage2], Union[_TestMessageR1, _TestMessageR2]]:
|
||||
self,
|
||||
call: Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]],
|
||||
) -> Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]]:
|
||||
...
|
||||
|
||||
def handler(self, call: Callable) -> Callable:
|
||||
"""Decorator to register a handler for a particular message type."""
|
||||
"""Decorator to register message handlers."""
|
||||
self.register_handler(call)
|
||||
return call
|
||||
|
||||
|
||||
class _BoundTestMessageReceiver:
|
||||
"""Testing type overrides for message receiving.
|
||||
"""Protocol-specific bound receiver."""
|
||||
|
||||
Normally this would be auto-generated based on the protocol.
|
||||
"""
|
||||
|
||||
def __init__(self, obj: Any, receiver: _TestMessageReceiver) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
obj: Any,
|
||||
receiver: _TestMessageReceiver,
|
||||
) -> None:
|
||||
assert obj is not None
|
||||
self._obj = obj
|
||||
self._receiver = receiver
|
||||
@ -145,12 +147,14 @@ class _BoundTestMessageReceiver:
|
||||
return self._receiver.handle_raw_message(self._obj, message)
|
||||
|
||||
|
||||
# RCV_CODE_TEST_END
|
||||
|
||||
TEST_PROTOCOL = MessageProtocol(
|
||||
message_types={
|
||||
1: _TestMessage1,
|
||||
2: _TestMessage2,
|
||||
3: _TestMessageR1,
|
||||
4: _TestMessageR2,
|
||||
1: _TMessage1,
|
||||
2: _TMessage2,
|
||||
3: _TResponse1,
|
||||
4: _TResponse2,
|
||||
},
|
||||
trusted_client=True,
|
||||
log_remote_exceptions=False,
|
||||
@ -160,20 +164,61 @@ TEST_PROTOCOL = MessageProtocol(
|
||||
def test_protocol_creation() -> None:
|
||||
"""Test protocol creation."""
|
||||
|
||||
# This should fail because _TestMessage1 can return _TestMessageR1 which
|
||||
# This should fail because _TMessage1 can return _TResponse1 which
|
||||
# is not given an id here.
|
||||
with pytest.raises(ValueError):
|
||||
_protocol = MessageProtocol(message_types={1: _TestMessage1})
|
||||
_protocol = MessageProtocol(message_types={1: _TMessage1})
|
||||
|
||||
# Now it should work.
|
||||
_protocol = MessageProtocol(message_types={
|
||||
1: _TestMessage1,
|
||||
2: _TestMessageR1
|
||||
})
|
||||
_protocol = MessageProtocol(message_types={1: _TMessage1, 2: _TResponse1})
|
||||
|
||||
|
||||
def test_sender_module_creation() -> None:
|
||||
"""Test generation of protocol-specific sender modules for typing/etc."""
|
||||
smod = TEST_PROTOCOL.create_sender_module('Test', private=True)
|
||||
|
||||
# Clip everything up to our first class declaration.
|
||||
lines = smod.splitlines()
|
||||
classline = lines.index('class _TestMessageSender(MessageSender):')
|
||||
clipped = '\n'.join(lines[classline:])
|
||||
|
||||
# This snippet should match what we've got embedded above;
|
||||
# If not then we need to update our test code.
|
||||
with open(__file__, encoding='utf-8') as infile:
|
||||
ourcode = infile.read()
|
||||
|
||||
emb = f'# SEND_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# SEND_CODE_TEST_END\n'
|
||||
if emb not in ourcode:
|
||||
print(f'EXPECTED EMBEDDED CODE:\n{emb}')
|
||||
raise RuntimeError('Generated sender module does not match embedded;'
|
||||
' test code needs to be updated.'
|
||||
' See test stdout for new code.')
|
||||
|
||||
|
||||
def test_receiver_module_creation() -> None:
|
||||
"""Test generation of protocol-specific sender modules for typing/etc."""
|
||||
smod = TEST_PROTOCOL.create_receiver_module('Test', private=True)
|
||||
|
||||
# Clip everything up to our first class declaration.
|
||||
lines = smod.splitlines()
|
||||
classline = lines.index('class _TestMessageReceiver(MessageReceiver):')
|
||||
clipped = '\n'.join(lines[classline:])
|
||||
|
||||
# This snippet should match what we've got embedded above;
|
||||
# If not then we need to update our test code.
|
||||
with open(__file__, encoding='utf-8') as infile:
|
||||
ourcode = infile.read()
|
||||
|
||||
emb = f'# RCV_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCV_CODE_TEST_END\n'
|
||||
if emb not in ourcode:
|
||||
print(f'EXPECTED EMBEDDED CODE:\n{emb}')
|
||||
raise RuntimeError('Generated sender module does not match embedded;'
|
||||
' test code needs to be updated.'
|
||||
' See test stdout for new code.')
|
||||
|
||||
|
||||
def test_receiver_creation() -> None:
|
||||
"""Test receiver creation"""
|
||||
"""Test receiver creation."""
|
||||
|
||||
# This should fail due to the registered handler only specifying
|
||||
# one response message type while the message type itself
|
||||
@ -186,12 +231,10 @@ def test_receiver_creation() -> None:
|
||||
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
||||
|
||||
@receiver.handler
|
||||
def handle_test_message_2(self,
|
||||
msg: _TestMessage2) -> _TestMessageR2:
|
||||
def handle_test_message_2(self, msg: _TMessage2) -> _TResponse2:
|
||||
"""Test."""
|
||||
del msg # Unused
|
||||
print('Hello from test message 1 handler!')
|
||||
return _TestMessageR2(fval=1.2)
|
||||
return _TResponse2(fval=1.2)
|
||||
|
||||
# Should fail because not all message types in the protocol are handled.
|
||||
with pytest.raises(TypeError):
|
||||
@ -200,7 +243,7 @@ def test_receiver_creation() -> None:
|
||||
"""Test class incorporating receive functionality."""
|
||||
|
||||
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
||||
receiver.validate_handler_completeness()
|
||||
receiver.validate()
|
||||
|
||||
|
||||
def test_message_sending() -> None:
|
||||
@ -226,42 +269,40 @@ def test_message_sending() -> None:
|
||||
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
||||
|
||||
@receiver.handler
|
||||
def handle_test_message_1(self, msg: _TestMessage1) -> _TestMessageR1:
|
||||
def handle_test_message_1(self, msg: _TMessage1) -> _TResponse1:
|
||||
"""Test."""
|
||||
print('Hello from test message 1 handler!')
|
||||
if msg.ival == 1:
|
||||
raise CleanError('Testing Clean Error')
|
||||
if msg.ival == 2:
|
||||
raise RuntimeError('Testing Runtime Error')
|
||||
return _TestMessageR1(bval=True)
|
||||
return _TResponse1(bval=True)
|
||||
|
||||
@receiver.handler
|
||||
def handle_test_message_2(
|
||||
self,
|
||||
msg: _TestMessage2) -> Union[_TestMessageR1, _TestMessageR2]:
|
||||
self, msg: _TMessage2) -> Union[_TResponse1, _TResponse2]:
|
||||
"""Test."""
|
||||
del msg # Unused
|
||||
print('Hello from test message 2 handler!')
|
||||
return _TestMessageR2(fval=1.2)
|
||||
return _TResponse2(fval=1.2)
|
||||
|
||||
receiver.validate_handler_completeness()
|
||||
receiver.validate()
|
||||
|
||||
obj_r = TestClassR()
|
||||
obj_s = TestClassS(target=obj_r)
|
||||
|
||||
response = obj_s.msg.send(_TestMessage1(ival=0))
|
||||
response2 = obj_s.msg.send(_TestMessage2(sval='rah'))
|
||||
assert static_type_equals(response, _TestMessageR1)
|
||||
assert isinstance(response, _TestMessageR1)
|
||||
assert isinstance(response2, (_TestMessageR1, _TestMessageR2))
|
||||
if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
|
||||
response = obj_s.msg.send(_TMessage1(ival=0))
|
||||
response2 = obj_s.msg.send(_TMessage2(sval='rah'))
|
||||
assert static_type_equals(response, _TResponse1)
|
||||
assert isinstance(response, _TResponse1)
|
||||
assert isinstance(response2, (_TResponse1, _TResponse2))
|
||||
|
||||
# Remote CleanErrors should come across locally as the same.
|
||||
try:
|
||||
_response3 = obj_s.msg.send(_TestMessage1(ival=1))
|
||||
_response3 = obj_s.msg.send(_TMessage1(ival=1))
|
||||
except Exception as exc:
|
||||
assert isinstance(exc, CleanError)
|
||||
assert str(exc) == 'Testing Clean Error'
|
||||
|
||||
# Other remote errors should come across as RemoteError.
|
||||
with pytest.raises(RemoteError):
|
||||
_response4 = obj_s.msg.send(_TestMessage1(ival=2))
|
||||
_response4 = obj_s.msg.send(_TMessage1(ival=2))
|
||||
|
||||
@ -22,7 +22,7 @@ from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs,
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set,
|
||||
Sequence)
|
||||
Sequence, Union)
|
||||
from efro.error import CommunicationError
|
||||
|
||||
TM = TypeVar('TM', bound='MessageSender')
|
||||
@ -35,21 +35,24 @@ class RemoteErrorType(Enum):
|
||||
|
||||
|
||||
class Message:
|
||||
"""Base class for messages and their responses."""
|
||||
"""Base class for messages."""
|
||||
|
||||
@classmethod
|
||||
def get_response_types(cls) -> List[Type[Message]]:
|
||||
def get_response_types(cls) -> List[Type[Response]]:
|
||||
"""Return all message types this Message can result in when sent.
|
||||
Messages intended only for response types can leave this empty.
|
||||
Note: RemoteErrorMessage is handled transparently and does not
|
||||
need to be specified here.
|
||||
"""
|
||||
return []
|
||||
|
||||
|
||||
class Response:
|
||||
"""Base class for responses to messages."""
|
||||
|
||||
|
||||
@ioprepped
|
||||
@dataclass
|
||||
class RemoteErrorMessage(Message):
|
||||
class RemoteErrorMessage(Response):
|
||||
"""Message saying some error has occurred on the other end."""
|
||||
error_message: Annotated[str, IOAttrs('m')]
|
||||
error_type: Annotated[RemoteErrorType, IOAttrs('t')]
|
||||
@ -64,14 +67,13 @@ class MessageProtocol:
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
message_types: Dict[int, Type[Message]],
|
||||
message_types: Dict[int, Union[Type[Message],
|
||||
Type[Response]]],
|
||||
type_key: Optional[str] = None,
|
||||
preserve_clean_errors: bool = True,
|
||||
log_remote_exceptions: bool = True,
|
||||
trusted_client: bool = False) -> None:
|
||||
"""Create a protocol with a given configuration.
|
||||
Each entry for message_types should contain an ID, a message type,
|
||||
and all possible response types.
|
||||
|
||||
If 'type_key' is provided, the message type ID is stored as the
|
||||
provided key in the message dict; otherwise it will be stored as
|
||||
@ -86,8 +88,10 @@ class MessageProtocol:
|
||||
be included in the RemoteError. This should only be enabled in cases
|
||||
where the client is trusted.
|
||||
"""
|
||||
self.message_types_by_id: Dict[int, Type[Message]] = {}
|
||||
self.message_ids_by_type: Dict[Type[Message], int] = {}
|
||||
self.message_types_by_id: Dict[int, Union[Type[Message],
|
||||
Type[Response]]] = {}
|
||||
self.message_ids_by_type: Dict[Union[Type[Message], Type[Response]],
|
||||
int] = {}
|
||||
for m_id, m_type in message_types.items():
|
||||
|
||||
# Make sure only valid message types were passed and each
|
||||
@ -95,34 +99,47 @@ class MessageProtocol:
|
||||
assert isinstance(m_id, int)
|
||||
assert m_id >= 0
|
||||
assert (is_ioprepped_dataclass(m_type)
|
||||
and issubclass(m_type, Message))
|
||||
and issubclass(m_type, (Message, Response)))
|
||||
assert self.message_types_by_id.get(m_id) is None
|
||||
|
||||
self.message_types_by_id[m_id] = m_type
|
||||
self.message_ids_by_type[m_type] = m_id
|
||||
|
||||
# Make sure all return types are valid and have been assigned
|
||||
# an ID as well.
|
||||
# Some extra-thorough validation in debug mode.
|
||||
if __debug__:
|
||||
all_response_types: Set[Type[Message]] = set()
|
||||
# Make sure all return types are valid and have been assigned
|
||||
# an ID as well.
|
||||
all_response_types: Set[Type[Response]] = set()
|
||||
for m_id, m_type in message_types.items():
|
||||
m_rtypes = m_type.get_response_types()
|
||||
assert isinstance(m_rtypes, list)
|
||||
assert len(set(m_rtypes)) == len(m_rtypes) # check for dups
|
||||
all_response_types.update(m_rtypes)
|
||||
if issubclass(m_type, Message):
|
||||
m_rtypes = m_type.get_response_types()
|
||||
assert isinstance(m_rtypes, list)
|
||||
assert m_rtypes # make sure not empty
|
||||
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
|
||||
all_response_types.update(m_rtypes)
|
||||
for cls in all_response_types:
|
||||
assert is_ioprepped_dataclass(cls) and issubclass(cls, Message)
|
||||
assert is_ioprepped_dataclass(cls) and issubclass(
|
||||
cls, (Message, Response))
|
||||
if cls not in self.message_ids_by_type:
|
||||
raise ValueError(f'Possible response type {cls}'
|
||||
f' was not included in message_types.')
|
||||
|
||||
# Make sure all registered types have unique base names.
|
||||
# We can take advantage of this to generate cleaner looking
|
||||
# protocol modules. Can revisit if this is ever a problem.
|
||||
mtypenames = set(tp.__name__ for tp in self.message_ids_by_type)
|
||||
if len(mtypenames) != len(message_types):
|
||||
raise ValueError(
|
||||
'message_types contains duplicate __name__s;'
|
||||
' all types are required to have unique names.')
|
||||
|
||||
self._type_key = type_key
|
||||
self.preserve_clean_errors = preserve_clean_errors
|
||||
self.log_remote_exceptions = log_remote_exceptions
|
||||
self.trusted_client = trusted_client
|
||||
|
||||
def message_encode(self,
|
||||
message: Message,
|
||||
message: Union[Message, Response],
|
||||
is_error: bool = False) -> bytes:
|
||||
"""Encode a message to bytes for transport."""
|
||||
|
||||
@ -132,8 +149,9 @@ class MessageProtocol:
|
||||
else:
|
||||
m_id = self.message_ids_by_type.get(type(message))
|
||||
if m_id is None:
|
||||
raise TypeError(f'Message type is not registered in Protocol:'
|
||||
f' {type(message)}')
|
||||
raise TypeError(
|
||||
f'Message/Response type is not registered in Protocol:'
|
||||
f' {type(message)}')
|
||||
msgdict = dataclass_to_dict(message)
|
||||
|
||||
# Encode type as part of the message dict if desired
|
||||
@ -148,7 +166,7 @@ class MessageProtocol:
|
||||
out = {'m': msgdict, 't': m_id}
|
||||
return json.dumps(out, separators=(',', ':')).encode()
|
||||
|
||||
def message_decode(self, data: bytes) -> Message:
|
||||
def message_decode(self, data: bytes) -> Union[Message, Response]:
|
||||
"""Decode a message from bytes.
|
||||
|
||||
If the message represents a remote error, an Exception will
|
||||
@ -178,25 +196,200 @@ class MessageProtocol:
|
||||
# Decode this particular type and make sure its valid.
|
||||
msgtype = self.message_types_by_id.get(m_id)
|
||||
if msgtype is None:
|
||||
raise TypeError(f'Got unregistered message type id of {m_id}.')
|
||||
raise TypeError(
|
||||
f'Got unregistered message/response type id of {m_id}.')
|
||||
|
||||
return dataclass_from_dict(msgtype, msgdict)
|
||||
out = dataclass_from_dict(msgtype, msgdict)
|
||||
assert isinstance(out, (Message, Response))
|
||||
return out
|
||||
|
||||
def create_sender_module(self, classname: str) -> str:
|
||||
def _get_module_header(self, part: str) -> str:
|
||||
"""Return common parts of generated modules."""
|
||||
imports: Dict[str, List[str]] = {}
|
||||
for msgtype in self.message_ids_by_type:
|
||||
imports.setdefault(msgtype.__module__, []).append(msgtype.__name__)
|
||||
importlines = ''
|
||||
for module, names in sorted(imports.items()):
|
||||
jnames = ', '.join(names)
|
||||
line = f'from {module} import {jnames}'
|
||||
if len(line) > 79:
|
||||
# Recreate in a wrapping-friendly form.
|
||||
line = f'from {module} import ({jnames})'
|
||||
importlines += f'{line}\n'
|
||||
|
||||
if part == 'sender':
|
||||
importlines = (
|
||||
f'from efro.message import MessageSender\n{importlines}')
|
||||
tpimports = 'from efro.message import Message, Response'
|
||||
else:
|
||||
importlines = (
|
||||
f'from efro.message import MessageReceiver\n{importlines}')
|
||||
tpimports = 'from efro.message import Message, Response'
|
||||
|
||||
out = ('# Released under the MIT License. See LICENSE for details.\n'
|
||||
f'#\n'
|
||||
f'"""Auto-generated {part} module."""\n'
|
||||
f'\n'
|
||||
f'from __future__ import annotations\n'
|
||||
f'\n'
|
||||
f'from typing import TYPE_CHECKING, overload\n'
|
||||
f'\n'
|
||||
f'{importlines}'
|
||||
f'\n'
|
||||
f'if TYPE_CHECKING:\n'
|
||||
f' from typing import Union\n'
|
||||
f' {tpimports}\n'
|
||||
f'\n'
|
||||
f'\n')
|
||||
return out
|
||||
|
||||
def create_sender_module(self,
|
||||
classname: str,
|
||||
private: bool = False) -> str:
|
||||
""""Create a Python module defining a MessageSender subclass.
|
||||
|
||||
This class is primarily for type checking and will contain overrides
|
||||
for the varieties of send calls for message/response types defined
|
||||
in the protocol.
|
||||
|
||||
Note that line lengths are not clipped, so output may need to be
|
||||
run through a formatter to prevent lint warnings about excessive
|
||||
line lengths.
|
||||
"""
|
||||
|
||||
def create_receiver_module(self, classname: str) -> str:
|
||||
ppre = '_' if private else ''
|
||||
out = self._get_module_header('sender')
|
||||
out += (f'class {ppre}{classname}MessageSender(MessageSender):\n'
|
||||
f' """Protocol-specific sender."""\n'
|
||||
f'\n'
|
||||
f' def __get__(self,\n'
|
||||
f' obj: Any,\n'
|
||||
f' type_in: Any = None)'
|
||||
f' -> {ppre}Bound{classname}MessageSender:\n'
|
||||
f' return {ppre}Bound{classname}MessageSender'
|
||||
f'(obj, self)\n'
|
||||
f'\n'
|
||||
f'\n'
|
||||
f'class {ppre}Bound{classname}MessageSender:\n'
|
||||
f' """Protocol-specific bound sender."""\n'
|
||||
f'\n'
|
||||
f' def __init__(self, obj: Any,'
|
||||
f' sender: {ppre}{classname}MessageSender) -> None:\n'
|
||||
f' assert obj is not None\n'
|
||||
f' self._obj = obj\n'
|
||||
f' self._sender = sender\n')
|
||||
|
||||
# Define handler() overloads for all registered message types.
|
||||
msgtypes = [
|
||||
t for t in self.message_ids_by_type if issubclass(t, Message)
|
||||
]
|
||||
|
||||
# Ew; @overload requires at least 2 different signatures so
|
||||
# we need to simply write a single function if we have < 2.
|
||||
if len(msgtypes) == 1:
|
||||
raise RuntimeError('FIXME: currently require at least 2'
|
||||
' message types.')
|
||||
if len(msgtypes) > 1:
|
||||
for msgtype in msgtypes:
|
||||
msgtypevar = msgtype.__name__
|
||||
rtypes = msgtype.get_response_types()
|
||||
if len(rtypes) > 1:
|
||||
tps = ', '.join(t.__name__ for t in rtypes)
|
||||
responsetypevar = f'Union[{tps}]'
|
||||
else:
|
||||
responsetypevar = rtypes[0].__name__
|
||||
out += (f'\n'
|
||||
f' @overload\n'
|
||||
f' def send(self, message: {msgtypevar})'
|
||||
f' -> {responsetypevar}:\n'
|
||||
f' ...\n')
|
||||
out += ('\n'
|
||||
' def send(self, message: Message) -> Response:\n'
|
||||
' """Send a message."""\n'
|
||||
' return self._sender.send(self._obj, message)\n')
|
||||
|
||||
return out
|
||||
|
||||
def create_receiver_module(self,
|
||||
classname: str,
|
||||
private: bool = False) -> str:
|
||||
""""Create a Python module defining a MessageReceiver subclass.
|
||||
|
||||
This class is primarily for type checking and will contain overrides
|
||||
for the register method for message/response types defined in
|
||||
the protocol.
|
||||
|
||||
Note that line lengths are not clipped, so output may need to be
|
||||
run through a formatter to prevent lint warnings about excessive
|
||||
line lengths.
|
||||
"""
|
||||
ppre = '_' if private else ''
|
||||
out = self._get_module_header('receiver')
|
||||
out += (f'class {ppre}{classname}MessageReceiver(MessageReceiver):\n'
|
||||
f' """Protocol-specific receiver."""\n'
|
||||
f'\n'
|
||||
f' def __get__(\n'
|
||||
f' self,\n'
|
||||
f' obj: Any,\n'
|
||||
f' type_in: Any = None,\n'
|
||||
f' ) -> {ppre}Bound{classname}MessageReceiver:\n'
|
||||
f' return {ppre}Bound{classname}MessageReceiver('
|
||||
f'obj, self)\n')
|
||||
|
||||
# Define handler() overloads for all registered message types.
|
||||
msgtypes = [
|
||||
t for t in self.message_ids_by_type if issubclass(t, Message)
|
||||
]
|
||||
|
||||
# Ew; @overload requires at least 2 different signatures so
|
||||
# we need to simply write a single function if we have < 2.
|
||||
if len(msgtypes) == 1:
|
||||
raise RuntimeError('FIXME: currently require at least 2'
|
||||
' message types.')
|
||||
if len(msgtypes) > 1:
|
||||
for msgtype in msgtypes:
|
||||
msgtypevar = msgtype.__name__
|
||||
rtypes = msgtype.get_response_types()
|
||||
if len(rtypes) > 1:
|
||||
tps = ', '.join(t.__name__ for t in rtypes)
|
||||
rtypevar = f'Union[{tps}]'
|
||||
else:
|
||||
rtypevar = rtypes[0].__name__
|
||||
out += (
|
||||
f'\n'
|
||||
f' @overload\n'
|
||||
f' def handler(\n'
|
||||
f' self,\n'
|
||||
f' call: Callable[[Any, {msgtypevar}], '
|
||||
f'{rtypevar}],\n'
|
||||
f' ) -> Callable[[Any, {msgtypevar}], {rtypevar}]:\n'
|
||||
f' ...\n')
|
||||
out += ('\n'
|
||||
' def handler(self, call: Callable) -> Callable:\n'
|
||||
' """Decorator to register message handlers."""\n'
|
||||
' self.register_handler(call)\n'
|
||||
' return call\n')
|
||||
|
||||
out += (f'\n'
|
||||
f'\n'
|
||||
f'class {ppre}Bound{classname}MessageReceiver:\n'
|
||||
f' """Protocol-specific bound receiver."""\n'
|
||||
f'\n'
|
||||
f' def __init__(\n'
|
||||
f' self,\n'
|
||||
f' obj: Any,\n'
|
||||
f' receiver: _TestMessageReceiver,\n'
|
||||
f' ) -> None:\n'
|
||||
f' assert obj is not None\n'
|
||||
f' self._obj = obj\n'
|
||||
f' self._receiver = receiver\n'
|
||||
f'\n'
|
||||
f' def handle_raw_message(self, message: bytes) -> bytes:\n'
|
||||
f' """Handle a raw incoming message."""\n'
|
||||
f' return self._receiver.handle_raw_message'
|
||||
f'(self._obj, message)\n')
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class MessageSender:
|
||||
@ -232,7 +425,7 @@ class MessageSender:
|
||||
self._send_raw_message_call = call
|
||||
return call
|
||||
|
||||
def send(self, bound_obj: Any, message: Message) -> Message:
|
||||
def send(self, bound_obj: Any, message: Message) -> Response:
|
||||
"""Send a message and receive a response.
|
||||
|
||||
Will encode the message for transport and call dispatch_raw_message()
|
||||
@ -240,13 +433,10 @@ class MessageSender:
|
||||
if self._send_raw_message_call is None:
|
||||
raise RuntimeError('send() is unimplemented for this type.')
|
||||
|
||||
# Only types with possible response types should ever be sent.
|
||||
assert type(message).get_response_types()
|
||||
|
||||
msg_encoded = self._protocol.message_encode(message)
|
||||
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
|
||||
response = self._protocol.message_decode(response_encoded)
|
||||
assert isinstance(response, Message)
|
||||
assert isinstance(response, Response)
|
||||
assert type(response) in type(message).get_response_types()
|
||||
return response
|
||||
|
||||
@ -278,10 +468,10 @@ class MessageReceiver:
|
||||
class MyClass:
|
||||
receiver = MyMessageReceiver()
|
||||
|
||||
# MyMessageReceiver should provide overloads to register_handler()
|
||||
# to ensure all registered handlers have valid types/return-types.
|
||||
# MyMessageReceiver fills out handler() overloads to ensure all
|
||||
# registered handlers have valid types/return-types.
|
||||
@receiver.handler
|
||||
def handle_some_message_type(self, message: SomeType) -> AnotherType:
|
||||
def handle_some_message_type(self, message: SomeMsg) -> SomeResponse:
|
||||
# Deal with this message type here.
|
||||
|
||||
# This will trigger the registered handler being called.
|
||||
@ -298,7 +488,7 @@ class MessageReceiver:
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
def register_handler(self, call: Callable[[Any, Message],
|
||||
Message]) -> None:
|
||||
Response]) -> None:
|
||||
"""Register a handler call.
|
||||
|
||||
The message type handled by the call is determined by its
|
||||
@ -366,14 +556,10 @@ class MessageReceiver:
|
||||
# Ok; we're good!
|
||||
self._handlers[msgtype] = call
|
||||
|
||||
def validate_handler_completeness(self, warn_only: bool = False) -> None:
|
||||
"""Return whether this receiver handles all protocol messages.
|
||||
|
||||
Only messages having possible response types are considered, as
|
||||
those are the only ones that can be sent to a receiver.
|
||||
"""
|
||||
def validate(self, warn_only: bool = False) -> None:
|
||||
"""Check for handler completeness, valid types, etc."""
|
||||
for msgtype in self._protocol.message_ids_by_type.keys():
|
||||
if not msgtype.get_response_types():
|
||||
if issubclass(msgtype, Response):
|
||||
continue
|
||||
if msgtype not in self._handlers:
|
||||
msg = (f'Protocol message {msgtype} not handled'
|
||||
@ -388,6 +574,7 @@ class MessageReceiver:
|
||||
# Decode the incoming message.
|
||||
msg_decoded = self._protocol.message_decode(msg)
|
||||
msgtype = type(msg_decoded)
|
||||
assert issubclass(msgtype, Message)
|
||||
|
||||
# Call the proper handler.
|
||||
handler = self._handlers.get(msgtype)
|
||||
@ -396,7 +583,7 @@ class MessageReceiver:
|
||||
response = handler(bound_obj, msg_decoded)
|
||||
|
||||
# Re-encode the response.
|
||||
assert isinstance(response, Message)
|
||||
assert isinstance(response, Response)
|
||||
assert type(response) in msgtype.get_response_types()
|
||||
return self._protocol.message_encode(response)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user