module generation for new messaging stuff

This commit is contained in:
Eric Froemling 2021-09-08 12:00:14 -05:00
parent 14c1a20ad0
commit 95bbb89d14
No known key found for this signature in database
GPG Key ID: 89C93F0F8D6D5A98
5 changed files with 369 additions and 117 deletions

View File

@ -361,6 +361,7 @@
<w>chromebooks</w> <w>chromebooks</w>
<w>chunksize</w> <w>chunksize</w>
<w>cjkcodecs</w> <w>cjkcodecs</w>
<w>classline</w>
<w>classmethod</w> <w>classmethod</w>
<w>classmethods</w> <w>classmethods</w>
<w>classname</w> <w>classname</w>
@ -1060,6 +1061,7 @@
<w>imgh</w> <w>imgh</w>
<w>imghdr</w> <w>imghdr</w>
<w>imgw</w> <w>imgw</w>
<w>importlines</w>
<w>incentivized</w> <w>incentivized</w>
<w>includetest</w> <w>includetest</w>
<w>incmd</w> <w>incmd</w>
@ -1135,6 +1137,7 @@
<w>jisx</w> <w>jisx</w>
<w>jite</w> <w>jite</w>
<w>jittering</w> <w>jittering</w>
<w>jnames</w>
<w>joedeshon</w> <w>joedeshon</w>
<w>johab</w> <w>johab</w>
<w>joinable</w> <w>joinable</w>
@ -1406,6 +1409,8 @@
<w>msgdict</w> <w>msgdict</w>
<w>msgfull</w> <w>msgfull</w>
<w>msgtype</w> <w>msgtype</w>
<w>msgtypes</w>
<w>msgtypevar</w>
<w>mshell</w> <w>mshell</w>
<w>msvccompiler</w> <w>msvccompiler</w>
<w>msvcp</w> <w>msvcp</w>
@ -1414,6 +1419,7 @@
<w>mtrans</w> <w>mtrans</w>
<w>mtvos</w> <w>mtvos</w>
<w>mtype</w> <w>mtype</w>
<w>mtypenames</w>
<w>mult</w> <w>mult</w>
<w>multibytecodec</w> <w>multibytecodec</w>
<w>multikillcount</w> <w>multikillcount</w>
@ -1570,6 +1576,7 @@
<w>osval</w> <w>osval</w>
<w>otherplayer</w> <w>otherplayer</w>
<w>otherspawn</w> <w>otherspawn</w>
<w>ourcode</w>
<w>ourhash</w> <w>ourhash</w>
<w>ourname</w> <w>ourname</w>
<w>ourself</w> <w>ourself</w>
@ -1699,6 +1706,7 @@
<w>poweruptype</w> <w>poweruptype</w>
<w>powervr</w> <w>powervr</w>
<w>ppos</w> <w>ppos</w>
<w>ppre</w>
<w>pproxy</w> <w>pproxy</w>
<w>pptabcom</w> <w>pptabcom</w>
<w>pragmas</w> <w>pragmas</w>
@ -1886,6 +1894,7 @@
<w>respawnicon</w> <w>respawnicon</w>
<w>responsetype</w> <w>responsetype</w>
<w>responsetypes</w> <w>responsetypes</w>
<w>responsetypevar</w>
<w>resultstr</w> <w>resultstr</w>
<w>retrysecs</w> <w>retrysecs</w>
<w>returncode</w> <w>returncode</w>
@ -1921,6 +1930,7 @@
<w>rtnetlink</w> <w>rtnetlink</w>
<w>rtxt</w> <w>rtxt</w>
<w>rtypes</w> <w>rtypes</w>
<w>rtypevar</w>
<w>runmypy</w> <w>runmypy</w>
<w>runonly</w> <w>runonly</w>
<w>runpy</w> <w>runpy</w>
@ -2053,6 +2063,7 @@
<w>smag</w> <w>smag</w>
<w>smallscale</w> <w>smallscale</w>
<w>smlh</w> <w>smlh</w>
<w>smod</w>
<w>smoothstep</w> <w>smoothstep</w>
<w>smoothy</w> <w>smoothy</w>
<w>smtpd</w> <w>smtpd</w>
@ -2333,6 +2344,7 @@
<w>touchpad</w> <w>touchpad</w>
<w>tournamententry</w> <w>tournamententry</w>
<w>tournamentscores</w> <w>tournamentscores</w>
<w>tpimports</w>
<w>tplayer</w> <w>tplayer</w>
<w>tpos</w> <w>tpos</w>
<w>tproxy</w> <w>tproxy</w>

View File

@ -179,6 +179,7 @@
<w>chunksize</w> <w>chunksize</w>
<w>cjief</w> <w>cjief</w>
<w>classdict</w> <w>classdict</w>
<w>classline</w>
<w>cleanupcheck</w> <w>cleanupcheck</w>
<w>clientid</w> <w>clientid</w>
<w>clientinfo</w> <w>clientinfo</w>
@ -493,6 +494,7 @@
<w>illum</w> <w>illum</w>
<w>ilock</w> <w>ilock</w>
<w>imagewidget</w> <w>imagewidget</w>
<w>importlines</w>
<w>incentivized</w> <w>incentivized</w>
<w>inet</w> <w>inet</w>
<w>infotxt</w> <w>infotxt</w>
@ -533,6 +535,7 @@
<w>jaxis</w> <w>jaxis</w>
<w>jcjwf</w> <w>jcjwf</w>
<w>jmessage</w> <w>jmessage</w>
<w>jnames</w>
<w>keepalives</w> <w>keepalives</w>
<w>keyanntype</w> <w>keyanntype</w>
<w>keycode</w> <w>keycode</w>
@ -643,6 +646,9 @@
<w>msgdict</w> <w>msgdict</w>
<w>msgfull</w> <w>msgfull</w>
<w>msgtype</w> <w>msgtype</w>
<w>msgtypes</w>
<w>msgtypevar</w>
<w>mtypenames</w>
<w>mult</w> <w>mult</w>
<w>multing</w> <w>multing</w>
<w>multipass</w> <w>multipass</w>
@ -746,6 +752,7 @@
<w>osis</w> <w>osis</w>
<w>osssssssssss</w> <w>osssssssssss</w>
<w>ostype</w> <w>ostype</w>
<w>ourcode</w>
<w>ourname</w> <w>ourname</w>
<w>ourself</w> <w>ourself</w>
<w>ourstanding</w> <w>ourstanding</w>
@ -781,6 +788,7 @@
<w>postinit</w> <w>postinit</w>
<w>postrun</w> <w>postrun</w>
<w>powerup</w> <w>powerup</w>
<w>ppre</w>
<w>pptabcom</w> <w>pptabcom</w>
<w>precalc</w> <w>precalc</w>
<w>predeclare</w> <w>predeclare</w>
@ -872,6 +880,7 @@
<w>resetbtn</w> <w>resetbtn</w>
<w>resetinput</w> <w>resetinput</w>
<w>responsetypes</w> <w>responsetypes</w>
<w>responsetypevar</w>
<w>resync</w> <w>resync</w>
<w>retrysecs</w> <w>retrysecs</w>
<w>retval</w> <w>retval</w>
@ -889,6 +898,7 @@
<w>rscode</w> <w>rscode</w>
<w>rsgc</w> <w>rsgc</w>
<w>rtypes</w> <w>rtypes</w>
<w>rtypevar</w>
<w>runnables</w> <w>runnables</w>
<w>rvec</w> <w>rvec</w>
<w>rvel</w> <w>rvel</w>
@ -945,6 +955,7 @@
<w>simpletype</w> <w>simpletype</w>
<w>sisssssssss</w> <w>sisssssssss</w>
<w>sixteenbits</w> <w>sixteenbits</w>
<w>smod</w>
<w>smoothering</w> <w>smoothering</w>
<w>smoothstep</w> <w>smoothstep</w>
<w>smoothy</w> <w>smoothy</w>
@ -1053,6 +1064,7 @@
<w>touchpad</w> <w>touchpad</w>
<w>toucs</w> <w>toucs</w>
<w>toutf</w> <w>toutf</w>
<w>tpimports</w>
<w>tracebacks</w> <w>tracebacks</w>
<w>tracestr</w> <w>tracestr</w>
<w>trackpad</w> <w>trackpad</w>

View File

@ -1,5 +1,5 @@
<!-- THIS FILE IS AUTO GENERATED; DO NOT EDIT BY HAND --> <!-- THIS FILE IS AUTO GENERATED; DO NOT EDIT BY HAND -->
<h4><em>last updated on 2021-09-07 for Ballistica version 1.6.5 build 20391</em></h4> <h4><em>last updated on 2021-09-08 for Ballistica version 1.6.5 build 20391</em></h4>
<p>This page documents the Python classes and functions in the 'ba' module, <p>This page documents the Python classes and functions in the 'ba' module,
which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please <a href="mailto:support@froemling.net">let me know</a>. Happy modding!</p> which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please <a href="mailto:support@froemling.net">let me know</a>. Happy modding!</p>
<hr> <hr>

View File

@ -4,6 +4,7 @@
from __future__ import annotations from __future__ import annotations
import os
from typing import TYPE_CHECKING, overload from typing import TYPE_CHECKING, overload
from dataclasses import dataclass from dataclasses import dataclass
@ -11,7 +12,7 @@ import pytest
from efro.error import CleanError, RemoteError from efro.error import CleanError, RemoteError
from efro.dataclassio import ioprepped from efro.dataclassio import ioprepped
from efro.message import (Message, MessageProtocol, MessageSender, from efro.message import (Message, Response, MessageProtocol, MessageSender,
MessageReceiver) MessageReceiver)
from efrotools.statictest import static_type_equals from efrotools.statictest import static_type_equals
@ -21,52 +22,52 @@ if TYPE_CHECKING:
@ioprepped @ioprepped
@dataclass @dataclass
class _TestMessage1(Message): class _TMessage1(Message):
"""Just testing.""" """Just testing."""
ival: int ival: int
@classmethod @classmethod
def get_response_types(cls) -> List[Type[Message]]: def get_response_types(cls) -> List[Type[Response]]:
return [_TestMessageR1] return [_TResponse1]
@ioprepped @ioprepped
@dataclass @dataclass
class _TestMessage2(Message): class _TMessage2(Message):
"""Just testing.""" """Just testing."""
sval: str sval: str
@classmethod @classmethod
def get_response_types(cls) -> List[Type[Message]]: def get_response_types(cls) -> List[Type[Response]]:
return [_TestMessageR1, _TestMessageR2] return [_TResponse1, _TResponse2]
@ioprepped @ioprepped
@dataclass @dataclass
class _TestMessageR1(Message): class _TResponse1(Response):
"""Just testing.""" """Just testing."""
bval: bool bval: bool
@ioprepped @ioprepped
@dataclass @dataclass
class _TestMessageR2(Message): class _TResponse2(Response):
"""Just testing.""" """Just testing."""
fval: float fval: float
@ioprepped @ioprepped
@dataclass @dataclass
class _TestMessageR3(Message): class _TResponse3(Message):
"""Just testing.""" """Just testing."""
fval: float fval: float
class _TestMessageSender(MessageSender): # SEND_CODE_TEST_BEGIN
"""Testing type overrides for message sending.
Normally this would be auto-generated based on the protocol.
""" class _TestMessageSender(MessageSender):
"""Protocol-specific sender."""
def __get__(self, def __get__(self,
obj: Any, obj: Any,
@ -75,10 +76,7 @@ class _TestMessageSender(MessageSender):
class _BoundTestMessageSender: class _BoundTestMessageSender:
"""Testing type overrides for message sending. """Protocol-specific bound sender."""
Normally this would be auto-generated based on the protocol.
"""
def __init__(self, obj: Any, sender: _TestMessageSender) -> None: def __init__(self, obj: Any, sender: _TestMessageSender) -> None:
assert obj is not None assert obj is not None
@ -86,56 +84,60 @@ class _BoundTestMessageSender:
self._sender = sender self._sender = sender
@overload @overload
def send(self, message: _TestMessage1) -> _TestMessageR1: def send(self, message: _TMessage1) -> _TResponse1:
... ...
@overload @overload
def send(self, def send(self, message: _TMessage2) -> Union[_TResponse1, _TResponse2]:
message: _TestMessage2) -> Union[_TestMessageR1, _TestMessageR2]:
... ...
def send(self, message: Message) -> Message: def send(self, message: Message) -> Response:
"""Send a particular message type.""" """Send a message."""
return self._sender.send(self._obj, message) return self._sender.send(self._obj, message)
# SEND_CODE_TEST_END
# RCV_CODE_TEST_BEGIN
class _TestMessageReceiver(MessageReceiver): class _TestMessageReceiver(MessageReceiver):
"""Testing type overrides for message receiving. """Protocol-specific receiver."""
Normally this would be auto-generated based on the protocol. def __get__(
""" self,
obj: Any,
def __get__(self, type_in: Any = None,
obj: Any, ) -> _BoundTestMessageReceiver:
type_in: Any = None) -> _BoundTestMessageReceiver:
return _BoundTestMessageReceiver(obj, self) return _BoundTestMessageReceiver(obj, self)
@overload @overload
def handler( def handler(
self, call: Callable[[Any, _TestMessage1], _TestMessageR1] self,
) -> Callable[[Any, _TestMessage1], _TestMessageR1]: call: Callable[[Any, _TMessage1], _TResponse1],
) -> Callable[[Any, _TMessage1], _TResponse1]:
... ...
@overload @overload
def handler( def handler(
self, call: Callable[[Any, _TestMessage2], Union[_TestMessageR1, self,
_TestMessageR2]] call: Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]],
) -> Callable[[Any, _TestMessage2], Union[_TestMessageR1, _TestMessageR2]]: ) -> Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]]:
... ...
def handler(self, call: Callable) -> Callable: def handler(self, call: Callable) -> Callable:
"""Decorator to register a handler for a particular message type.""" """Decorator to register message handlers."""
self.register_handler(call) self.register_handler(call)
return call return call
class _BoundTestMessageReceiver: class _BoundTestMessageReceiver:
"""Testing type overrides for message receiving. """Protocol-specific bound receiver."""
Normally this would be auto-generated based on the protocol. def __init__(
""" self,
obj: Any,
def __init__(self, obj: Any, receiver: _TestMessageReceiver) -> None: receiver: _TestMessageReceiver,
) -> None:
assert obj is not None assert obj is not None
self._obj = obj self._obj = obj
self._receiver = receiver self._receiver = receiver
@ -145,12 +147,14 @@ class _BoundTestMessageReceiver:
return self._receiver.handle_raw_message(self._obj, message) return self._receiver.handle_raw_message(self._obj, message)
# RCV_CODE_TEST_END
TEST_PROTOCOL = MessageProtocol( TEST_PROTOCOL = MessageProtocol(
message_types={ message_types={
1: _TestMessage1, 1: _TMessage1,
2: _TestMessage2, 2: _TMessage2,
3: _TestMessageR1, 3: _TResponse1,
4: _TestMessageR2, 4: _TResponse2,
}, },
trusted_client=True, trusted_client=True,
log_remote_exceptions=False, log_remote_exceptions=False,
@ -160,20 +164,61 @@ TEST_PROTOCOL = MessageProtocol(
def test_protocol_creation() -> None: def test_protocol_creation() -> None:
"""Test protocol creation.""" """Test protocol creation."""
# This should fail because _TestMessage1 can return _TestMessageR1 which # This should fail because _TMessage1 can return _TResponse1 which
# is not given an id here. # is not given an id here.
with pytest.raises(ValueError): with pytest.raises(ValueError):
_protocol = MessageProtocol(message_types={1: _TestMessage1}) _protocol = MessageProtocol(message_types={1: _TMessage1})
# Now it should work. # Now it should work.
_protocol = MessageProtocol(message_types={ _protocol = MessageProtocol(message_types={1: _TMessage1, 2: _TResponse1})
1: _TestMessage1,
2: _TestMessageR1
}) def test_sender_module_creation() -> None:
"""Test generation of protocol-specific sender modules for typing/etc."""
smod = TEST_PROTOCOL.create_sender_module('Test', private=True)
# Clip everything up to our first class declaration.
lines = smod.splitlines()
classline = lines.index('class _TestMessageSender(MessageSender):')
clipped = '\n'.join(lines[classline:])
# This snippet should match what we've got embedded above;
# If not then we need to update our test code.
with open(__file__, encoding='utf-8') as infile:
ourcode = infile.read()
emb = f'# SEND_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# SEND_CODE_TEST_END\n'
if emb not in ourcode:
print(f'EXPECTED EMBEDDED CODE:\n{emb}')
raise RuntimeError('Generated sender module does not match embedded;'
' test code needs to be updated.'
' See test stdout for new code.')
def test_receiver_module_creation() -> None:
"""Test generation of protocol-specific sender modules for typing/etc."""
smod = TEST_PROTOCOL.create_receiver_module('Test', private=True)
# Clip everything up to our first class declaration.
lines = smod.splitlines()
classline = lines.index('class _TestMessageReceiver(MessageReceiver):')
clipped = '\n'.join(lines[classline:])
# This snippet should match what we've got embedded above;
# If not then we need to update our test code.
with open(__file__, encoding='utf-8') as infile:
ourcode = infile.read()
emb = f'# RCV_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCV_CODE_TEST_END\n'
if emb not in ourcode:
print(f'EXPECTED EMBEDDED CODE:\n{emb}')
raise RuntimeError('Generated sender module does not match embedded;'
' test code needs to be updated.'
' See test stdout for new code.')
def test_receiver_creation() -> None: def test_receiver_creation() -> None:
"""Test receiver creation""" """Test receiver creation."""
# This should fail due to the registered handler only specifying # This should fail due to the registered handler only specifying
# one response message type while the message type itself # one response message type while the message type itself
@ -186,12 +231,10 @@ def test_receiver_creation() -> None:
receiver = _TestMessageReceiver(TEST_PROTOCOL) receiver = _TestMessageReceiver(TEST_PROTOCOL)
@receiver.handler @receiver.handler
def handle_test_message_2(self, def handle_test_message_2(self, msg: _TMessage2) -> _TResponse2:
msg: _TestMessage2) -> _TestMessageR2:
"""Test.""" """Test."""
del msg # Unused del msg # Unused
print('Hello from test message 1 handler!') return _TResponse2(fval=1.2)
return _TestMessageR2(fval=1.2)
# Should fail because not all message types in the protocol are handled. # Should fail because not all message types in the protocol are handled.
with pytest.raises(TypeError): with pytest.raises(TypeError):
@ -200,7 +243,7 @@ def test_receiver_creation() -> None:
"""Test class incorporating receive functionality.""" """Test class incorporating receive functionality."""
receiver = _TestMessageReceiver(TEST_PROTOCOL) receiver = _TestMessageReceiver(TEST_PROTOCOL)
receiver.validate_handler_completeness() receiver.validate()
def test_message_sending() -> None: def test_message_sending() -> None:
@ -226,42 +269,40 @@ def test_message_sending() -> None:
receiver = _TestMessageReceiver(TEST_PROTOCOL) receiver = _TestMessageReceiver(TEST_PROTOCOL)
@receiver.handler @receiver.handler
def handle_test_message_1(self, msg: _TestMessage1) -> _TestMessageR1: def handle_test_message_1(self, msg: _TMessage1) -> _TResponse1:
"""Test.""" """Test."""
print('Hello from test message 1 handler!')
if msg.ival == 1: if msg.ival == 1:
raise CleanError('Testing Clean Error') raise CleanError('Testing Clean Error')
if msg.ival == 2: if msg.ival == 2:
raise RuntimeError('Testing Runtime Error') raise RuntimeError('Testing Runtime Error')
return _TestMessageR1(bval=True) return _TResponse1(bval=True)
@receiver.handler @receiver.handler
def handle_test_message_2( def handle_test_message_2(
self, self, msg: _TMessage2) -> Union[_TResponse1, _TResponse2]:
msg: _TestMessage2) -> Union[_TestMessageR1, _TestMessageR2]:
"""Test.""" """Test."""
del msg # Unused del msg # Unused
print('Hello from test message 2 handler!') return _TResponse2(fval=1.2)
return _TestMessageR2(fval=1.2)
receiver.validate_handler_completeness() receiver.validate()
obj_r = TestClassR() obj_r = TestClassR()
obj_s = TestClassS(target=obj_r) obj_s = TestClassS(target=obj_r)
response = obj_s.msg.send(_TestMessage1(ival=0)) if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
response2 = obj_s.msg.send(_TestMessage2(sval='rah')) response = obj_s.msg.send(_TMessage1(ival=0))
assert static_type_equals(response, _TestMessageR1) response2 = obj_s.msg.send(_TMessage2(sval='rah'))
assert isinstance(response, _TestMessageR1) assert static_type_equals(response, _TResponse1)
assert isinstance(response2, (_TestMessageR1, _TestMessageR2)) assert isinstance(response, _TResponse1)
assert isinstance(response2, (_TResponse1, _TResponse2))
# Remote CleanErrors should come across locally as the same. # Remote CleanErrors should come across locally as the same.
try: try:
_response3 = obj_s.msg.send(_TestMessage1(ival=1)) _response3 = obj_s.msg.send(_TMessage1(ival=1))
except Exception as exc: except Exception as exc:
assert isinstance(exc, CleanError) assert isinstance(exc, CleanError)
assert str(exc) == 'Testing Clean Error' assert str(exc) == 'Testing Clean Error'
# Other remote errors should come across as RemoteError. # Other remote errors should come across as RemoteError.
with pytest.raises(RemoteError): with pytest.raises(RemoteError):
_response4 = obj_s.msg.send(_TestMessage1(ival=2)) _response4 = obj_s.msg.send(_TMessage1(ival=2))

View File

@ -22,7 +22,7 @@ from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs,
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set, from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set,
Sequence) Sequence, Union)
from efro.error import CommunicationError from efro.error import CommunicationError
TM = TypeVar('TM', bound='MessageSender') TM = TypeVar('TM', bound='MessageSender')
@ -35,21 +35,24 @@ class RemoteErrorType(Enum):
class Message: class Message:
"""Base class for messages and their responses.""" """Base class for messages."""
@classmethod @classmethod
def get_response_types(cls) -> List[Type[Message]]: def get_response_types(cls) -> List[Type[Response]]:
"""Return all message types this Message can result in when sent. """Return all message types this Message can result in when sent.
Messages intended only for response types can leave this empty.
Note: RemoteErrorMessage is handled transparently and does not Note: RemoteErrorMessage is handled transparently and does not
need to be specified here. need to be specified here.
""" """
return [] return []
class Response:
"""Base class for responses to messages."""
@ioprepped @ioprepped
@dataclass @dataclass
class RemoteErrorMessage(Message): class RemoteErrorMessage(Response):
"""Message saying some error has occurred on the other end.""" """Message saying some error has occurred on the other end."""
error_message: Annotated[str, IOAttrs('m')] error_message: Annotated[str, IOAttrs('m')]
error_type: Annotated[RemoteErrorType, IOAttrs('t')] error_type: Annotated[RemoteErrorType, IOAttrs('t')]
@ -64,14 +67,13 @@ class MessageProtocol:
""" """
def __init__(self, def __init__(self,
message_types: Dict[int, Type[Message]], message_types: Dict[int, Union[Type[Message],
Type[Response]]],
type_key: Optional[str] = None, type_key: Optional[str] = None,
preserve_clean_errors: bool = True, preserve_clean_errors: bool = True,
log_remote_exceptions: bool = True, log_remote_exceptions: bool = True,
trusted_client: bool = False) -> None: trusted_client: bool = False) -> None:
"""Create a protocol with a given configuration. """Create a protocol with a given configuration.
Each entry for message_types should contain an ID, a message type,
and all possible response types.
If 'type_key' is provided, the message type ID is stored as the If 'type_key' is provided, the message type ID is stored as the
provided key in the message dict; otherwise it will be stored as provided key in the message dict; otherwise it will be stored as
@ -86,8 +88,10 @@ class MessageProtocol:
be included in the RemoteError. This should only be enabled in cases be included in the RemoteError. This should only be enabled in cases
where the client is trusted. where the client is trusted.
""" """
self.message_types_by_id: Dict[int, Type[Message]] = {} self.message_types_by_id: Dict[int, Union[Type[Message],
self.message_ids_by_type: Dict[Type[Message], int] = {} Type[Response]]] = {}
self.message_ids_by_type: Dict[Union[Type[Message], Type[Response]],
int] = {}
for m_id, m_type in message_types.items(): for m_id, m_type in message_types.items():
# Make sure only valid message types were passed and each # Make sure only valid message types were passed and each
@ -95,34 +99,47 @@ class MessageProtocol:
assert isinstance(m_id, int) assert isinstance(m_id, int)
assert m_id >= 0 assert m_id >= 0
assert (is_ioprepped_dataclass(m_type) assert (is_ioprepped_dataclass(m_type)
and issubclass(m_type, Message)) and issubclass(m_type, (Message, Response)))
assert self.message_types_by_id.get(m_id) is None assert self.message_types_by_id.get(m_id) is None
self.message_types_by_id[m_id] = m_type self.message_types_by_id[m_id] = m_type
self.message_ids_by_type[m_type] = m_id self.message_ids_by_type[m_type] = m_id
# Make sure all return types are valid and have been assigned # Some extra-thorough validation in debug mode.
# an ID as well.
if __debug__: if __debug__:
all_response_types: Set[Type[Message]] = set() # Make sure all return types are valid and have been assigned
# an ID as well.
all_response_types: Set[Type[Response]] = set()
for m_id, m_type in message_types.items(): for m_id, m_type in message_types.items():
m_rtypes = m_type.get_response_types() if issubclass(m_type, Message):
assert isinstance(m_rtypes, list) m_rtypes = m_type.get_response_types()
assert len(set(m_rtypes)) == len(m_rtypes) # check for dups assert isinstance(m_rtypes, list)
all_response_types.update(m_rtypes) assert m_rtypes # make sure not empty
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
all_response_types.update(m_rtypes)
for cls in all_response_types: for cls in all_response_types:
assert is_ioprepped_dataclass(cls) and issubclass(cls, Message) assert is_ioprepped_dataclass(cls) and issubclass(
cls, (Message, Response))
if cls not in self.message_ids_by_type: if cls not in self.message_ids_by_type:
raise ValueError(f'Possible response type {cls}' raise ValueError(f'Possible response type {cls}'
f' was not included in message_types.') f' was not included in message_types.')
# Make sure all registered types have unique base names.
# We can take advantage of this to generate cleaner looking
# protocol modules. Can revisit if this is ever a problem.
mtypenames = set(tp.__name__ for tp in self.message_ids_by_type)
if len(mtypenames) != len(message_types):
raise ValueError(
'message_types contains duplicate __name__s;'
' all types are required to have unique names.')
self._type_key = type_key self._type_key = type_key
self.preserve_clean_errors = preserve_clean_errors self.preserve_clean_errors = preserve_clean_errors
self.log_remote_exceptions = log_remote_exceptions self.log_remote_exceptions = log_remote_exceptions
self.trusted_client = trusted_client self.trusted_client = trusted_client
def message_encode(self, def message_encode(self,
message: Message, message: Union[Message, Response],
is_error: bool = False) -> bytes: is_error: bool = False) -> bytes:
"""Encode a message to bytes for transport.""" """Encode a message to bytes for transport."""
@ -132,8 +149,9 @@ class MessageProtocol:
else: else:
m_id = self.message_ids_by_type.get(type(message)) m_id = self.message_ids_by_type.get(type(message))
if m_id is None: if m_id is None:
raise TypeError(f'Message type is not registered in Protocol:' raise TypeError(
f' {type(message)}') f'Message/Response type is not registered in Protocol:'
f' {type(message)}')
msgdict = dataclass_to_dict(message) msgdict = dataclass_to_dict(message)
# Encode type as part of the message dict if desired # Encode type as part of the message dict if desired
@ -148,7 +166,7 @@ class MessageProtocol:
out = {'m': msgdict, 't': m_id} out = {'m': msgdict, 't': m_id}
return json.dumps(out, separators=(',', ':')).encode() return json.dumps(out, separators=(',', ':')).encode()
def message_decode(self, data: bytes) -> Message: def message_decode(self, data: bytes) -> Union[Message, Response]:
"""Decode a message from bytes. """Decode a message from bytes.
If the message represents a remote error, an Exception will If the message represents a remote error, an Exception will
@ -178,25 +196,200 @@ class MessageProtocol:
# Decode this particular type and make sure its valid. # Decode this particular type and make sure its valid.
msgtype = self.message_types_by_id.get(m_id) msgtype = self.message_types_by_id.get(m_id)
if msgtype is None: if msgtype is None:
raise TypeError(f'Got unregistered message type id of {m_id}.') raise TypeError(
f'Got unregistered message/response type id of {m_id}.')
return dataclass_from_dict(msgtype, msgdict) out = dataclass_from_dict(msgtype, msgdict)
assert isinstance(out, (Message, Response))
return out
def create_sender_module(self, classname: str) -> str: def _get_module_header(self, part: str) -> str:
"""Return common parts of generated modules."""
imports: Dict[str, List[str]] = {}
for msgtype in self.message_ids_by_type:
imports.setdefault(msgtype.__module__, []).append(msgtype.__name__)
importlines = ''
for module, names in sorted(imports.items()):
jnames = ', '.join(names)
line = f'from {module} import {jnames}'
if len(line) > 79:
# Recreate in a wrapping-friendly form.
line = f'from {module} import ({jnames})'
importlines += f'{line}\n'
if part == 'sender':
importlines = (
f'from efro.message import MessageSender\n{importlines}')
tpimports = 'from efro.message import Message, Response'
else:
importlines = (
f'from efro.message import MessageReceiver\n{importlines}')
tpimports = 'from efro.message import Message, Response'
out = ('# Released under the MIT License. See LICENSE for details.\n'
f'#\n'
f'"""Auto-generated {part} module."""\n'
f'\n'
f'from __future__ import annotations\n'
f'\n'
f'from typing import TYPE_CHECKING, overload\n'
f'\n'
f'{importlines}'
f'\n'
f'if TYPE_CHECKING:\n'
f' from typing import Union\n'
f' {tpimports}\n'
f'\n'
f'\n')
return out
def create_sender_module(self,
classname: str,
private: bool = False) -> str:
""""Create a Python module defining a MessageSender subclass. """"Create a Python module defining a MessageSender subclass.
This class is primarily for type checking and will contain overrides This class is primarily for type checking and will contain overrides
for the varieties of send calls for message/response types defined for the varieties of send calls for message/response types defined
in the protocol. in the protocol.
Note that line lengths are not clipped, so output may need to be
run through a formatter to prevent lint warnings about excessive
line lengths.
""" """
def create_receiver_module(self, classname: str) -> str: ppre = '_' if private else ''
out = self._get_module_header('sender')
out += (f'class {ppre}{classname}MessageSender(MessageSender):\n'
f' """Protocol-specific sender."""\n'
f'\n'
f' def __get__(self,\n'
f' obj: Any,\n'
f' type_in: Any = None)'
f' -> {ppre}Bound{classname}MessageSender:\n'
f' return {ppre}Bound{classname}MessageSender'
f'(obj, self)\n'
f'\n'
f'\n'
f'class {ppre}Bound{classname}MessageSender:\n'
f' """Protocol-specific bound sender."""\n'
f'\n'
f' def __init__(self, obj: Any,'
f' sender: {ppre}{classname}MessageSender) -> None:\n'
f' assert obj is not None\n'
f' self._obj = obj\n'
f' self._sender = sender\n')
# Define handler() overloads for all registered message types.
msgtypes = [
t for t in self.message_ids_by_type if issubclass(t, Message)
]
# Ew; @overload requires at least 2 different signatures so
# we need to simply write a single function if we have < 2.
if len(msgtypes) == 1:
raise RuntimeError('FIXME: currently require at least 2'
' message types.')
if len(msgtypes) > 1:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
tps = ', '.join(t.__name__ for t in rtypes)
responsetypevar = f'Union[{tps}]'
else:
responsetypevar = rtypes[0].__name__
out += (f'\n'
f' @overload\n'
f' def send(self, message: {msgtypevar})'
f' -> {responsetypevar}:\n'
f' ...\n')
out += ('\n'
' def send(self, message: Message) -> Response:\n'
' """Send a message."""\n'
' return self._sender.send(self._obj, message)\n')
return out
def create_receiver_module(self,
classname: str,
private: bool = False) -> str:
""""Create a Python module defining a MessageReceiver subclass. """"Create a Python module defining a MessageReceiver subclass.
This class is primarily for type checking and will contain overrides This class is primarily for type checking and will contain overrides
for the register method for message/response types defined in for the register method for message/response types defined in
the protocol. the protocol.
Note that line lengths are not clipped, so output may need to be
run through a formatter to prevent lint warnings about excessive
line lengths.
""" """
ppre = '_' if private else ''
out = self._get_module_header('receiver')
out += (f'class {ppre}{classname}MessageReceiver(MessageReceiver):\n'
f' """Protocol-specific receiver."""\n'
f'\n'
f' def __get__(\n'
f' self,\n'
f' obj: Any,\n'
f' type_in: Any = None,\n'
f' ) -> {ppre}Bound{classname}MessageReceiver:\n'
f' return {ppre}Bound{classname}MessageReceiver('
f'obj, self)\n')
# Define handler() overloads for all registered message types.
msgtypes = [
t for t in self.message_ids_by_type if issubclass(t, Message)
]
# Ew; @overload requires at least 2 different signatures so
# we need to simply write a single function if we have < 2.
if len(msgtypes) == 1:
raise RuntimeError('FIXME: currently require at least 2'
' message types.')
if len(msgtypes) > 1:
for msgtype in msgtypes:
msgtypevar = msgtype.__name__
rtypes = msgtype.get_response_types()
if len(rtypes) > 1:
tps = ', '.join(t.__name__ for t in rtypes)
rtypevar = f'Union[{tps}]'
else:
rtypevar = rtypes[0].__name__
out += (
f'\n'
f' @overload\n'
f' def handler(\n'
f' self,\n'
f' call: Callable[[Any, {msgtypevar}], '
f'{rtypevar}],\n'
f' ) -> Callable[[Any, {msgtypevar}], {rtypevar}]:\n'
f' ...\n')
out += ('\n'
' def handler(self, call: Callable) -> Callable:\n'
' """Decorator to register message handlers."""\n'
' self.register_handler(call)\n'
' return call\n')
out += (f'\n'
f'\n'
f'class {ppre}Bound{classname}MessageReceiver:\n'
f' """Protocol-specific bound receiver."""\n'
f'\n'
f' def __init__(\n'
f' self,\n'
f' obj: Any,\n'
f' receiver: _TestMessageReceiver,\n'
f' ) -> None:\n'
f' assert obj is not None\n'
f' self._obj = obj\n'
f' self._receiver = receiver\n'
f'\n'
f' def handle_raw_message(self, message: bytes) -> bytes:\n'
f' """Handle a raw incoming message."""\n'
f' return self._receiver.handle_raw_message'
f'(self._obj, message)\n')
return out
class MessageSender: class MessageSender:
@ -232,7 +425,7 @@ class MessageSender:
self._send_raw_message_call = call self._send_raw_message_call = call
return call return call
def send(self, bound_obj: Any, message: Message) -> Message: def send(self, bound_obj: Any, message: Message) -> Response:
"""Send a message and receive a response. """Send a message and receive a response.
Will encode the message for transport and call dispatch_raw_message() Will encode the message for transport and call dispatch_raw_message()
@ -240,13 +433,10 @@ class MessageSender:
if self._send_raw_message_call is None: if self._send_raw_message_call is None:
raise RuntimeError('send() is unimplemented for this type.') raise RuntimeError('send() is unimplemented for this type.')
# Only types with possible response types should ever be sent.
assert type(message).get_response_types()
msg_encoded = self._protocol.message_encode(message) msg_encoded = self._protocol.message_encode(message)
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded) response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
response = self._protocol.message_decode(response_encoded) response = self._protocol.message_decode(response_encoded)
assert isinstance(response, Message) assert isinstance(response, Response)
assert type(response) in type(message).get_response_types() assert type(response) in type(message).get_response_types()
return response return response
@ -278,10 +468,10 @@ class MessageReceiver:
class MyClass: class MyClass:
receiver = MyMessageReceiver() receiver = MyMessageReceiver()
# MyMessageReceiver should provide overloads to register_handler() # MyMessageReceiver fills out handler() overloads to ensure all
# to ensure all registered handlers have valid types/return-types. # registered handlers have valid types/return-types.
@receiver.handler @receiver.handler
def handle_some_message_type(self, message: SomeType) -> AnotherType: def handle_some_message_type(self, message: SomeMsg) -> SomeResponse:
# Deal with this message type here. # Deal with this message type here.
# This will trigger the registered handler being called. # This will trigger the registered handler being called.
@ -298,7 +488,7 @@ class MessageReceiver:
# noinspection PyProtectedMember # noinspection PyProtectedMember
def register_handler(self, call: Callable[[Any, Message], def register_handler(self, call: Callable[[Any, Message],
Message]) -> None: Response]) -> None:
"""Register a handler call. """Register a handler call.
The message type handled by the call is determined by its The message type handled by the call is determined by its
@ -366,14 +556,10 @@ class MessageReceiver:
# Ok; we're good! # Ok; we're good!
self._handlers[msgtype] = call self._handlers[msgtype] = call
def validate_handler_completeness(self, warn_only: bool = False) -> None: def validate(self, warn_only: bool = False) -> None:
"""Return whether this receiver handles all protocol messages. """Check for handler completeness, valid types, etc."""
Only messages having possible response types are considered, as
those are the only ones that can be sent to a receiver.
"""
for msgtype in self._protocol.message_ids_by_type.keys(): for msgtype in self._protocol.message_ids_by_type.keys():
if not msgtype.get_response_types(): if issubclass(msgtype, Response):
continue continue
if msgtype not in self._handlers: if msgtype not in self._handlers:
msg = (f'Protocol message {msgtype} not handled' msg = (f'Protocol message {msgtype} not handled'
@ -388,6 +574,7 @@ class MessageReceiver:
# Decode the incoming message. # Decode the incoming message.
msg_decoded = self._protocol.message_decode(msg) msg_decoded = self._protocol.message_decode(msg)
msgtype = type(msg_decoded) msgtype = type(msg_decoded)
assert issubclass(msgtype, Message)
# Call the proper handler. # Call the proper handler.
handler = self._handlers.get(msgtype) handler = self._handlers.get(msgtype)
@ -396,7 +583,7 @@ class MessageReceiver:
response = handler(bound_obj, msg_decoded) response = handler(bound_obj, msg_decoded)
# Re-encode the response. # Re-encode the response.
assert isinstance(response, Message) assert isinstance(response, Response)
assert type(response) in msgtype.get_response_types() assert type(response) in msgtype.get_response_types()
return self._protocol.message_encode(response) return self._protocol.message_encode(response)