mirror of
https://github.com/RYDE-WORK/ballistica.git
synced 2026-02-05 23:13:46 +08:00
module generation for new messaging stuff
This commit is contained in:
parent
14c1a20ad0
commit
95bbb89d14
12
.idea/dictionaries/ericf.xml
generated
12
.idea/dictionaries/ericf.xml
generated
@ -361,6 +361,7 @@
|
|||||||
<w>chromebooks</w>
|
<w>chromebooks</w>
|
||||||
<w>chunksize</w>
|
<w>chunksize</w>
|
||||||
<w>cjkcodecs</w>
|
<w>cjkcodecs</w>
|
||||||
|
<w>classline</w>
|
||||||
<w>classmethod</w>
|
<w>classmethod</w>
|
||||||
<w>classmethods</w>
|
<w>classmethods</w>
|
||||||
<w>classname</w>
|
<w>classname</w>
|
||||||
@ -1060,6 +1061,7 @@
|
|||||||
<w>imgh</w>
|
<w>imgh</w>
|
||||||
<w>imghdr</w>
|
<w>imghdr</w>
|
||||||
<w>imgw</w>
|
<w>imgw</w>
|
||||||
|
<w>importlines</w>
|
||||||
<w>incentivized</w>
|
<w>incentivized</w>
|
||||||
<w>includetest</w>
|
<w>includetest</w>
|
||||||
<w>incmd</w>
|
<w>incmd</w>
|
||||||
@ -1135,6 +1137,7 @@
|
|||||||
<w>jisx</w>
|
<w>jisx</w>
|
||||||
<w>jite</w>
|
<w>jite</w>
|
||||||
<w>jittering</w>
|
<w>jittering</w>
|
||||||
|
<w>jnames</w>
|
||||||
<w>joedeshon</w>
|
<w>joedeshon</w>
|
||||||
<w>johab</w>
|
<w>johab</w>
|
||||||
<w>joinable</w>
|
<w>joinable</w>
|
||||||
@ -1406,6 +1409,8 @@
|
|||||||
<w>msgdict</w>
|
<w>msgdict</w>
|
||||||
<w>msgfull</w>
|
<w>msgfull</w>
|
||||||
<w>msgtype</w>
|
<w>msgtype</w>
|
||||||
|
<w>msgtypes</w>
|
||||||
|
<w>msgtypevar</w>
|
||||||
<w>mshell</w>
|
<w>mshell</w>
|
||||||
<w>msvccompiler</w>
|
<w>msvccompiler</w>
|
||||||
<w>msvcp</w>
|
<w>msvcp</w>
|
||||||
@ -1414,6 +1419,7 @@
|
|||||||
<w>mtrans</w>
|
<w>mtrans</w>
|
||||||
<w>mtvos</w>
|
<w>mtvos</w>
|
||||||
<w>mtype</w>
|
<w>mtype</w>
|
||||||
|
<w>mtypenames</w>
|
||||||
<w>mult</w>
|
<w>mult</w>
|
||||||
<w>multibytecodec</w>
|
<w>multibytecodec</w>
|
||||||
<w>multikillcount</w>
|
<w>multikillcount</w>
|
||||||
@ -1570,6 +1576,7 @@
|
|||||||
<w>osval</w>
|
<w>osval</w>
|
||||||
<w>otherplayer</w>
|
<w>otherplayer</w>
|
||||||
<w>otherspawn</w>
|
<w>otherspawn</w>
|
||||||
|
<w>ourcode</w>
|
||||||
<w>ourhash</w>
|
<w>ourhash</w>
|
||||||
<w>ourname</w>
|
<w>ourname</w>
|
||||||
<w>ourself</w>
|
<w>ourself</w>
|
||||||
@ -1699,6 +1706,7 @@
|
|||||||
<w>poweruptype</w>
|
<w>poweruptype</w>
|
||||||
<w>powervr</w>
|
<w>powervr</w>
|
||||||
<w>ppos</w>
|
<w>ppos</w>
|
||||||
|
<w>ppre</w>
|
||||||
<w>pproxy</w>
|
<w>pproxy</w>
|
||||||
<w>pptabcom</w>
|
<w>pptabcom</w>
|
||||||
<w>pragmas</w>
|
<w>pragmas</w>
|
||||||
@ -1886,6 +1894,7 @@
|
|||||||
<w>respawnicon</w>
|
<w>respawnicon</w>
|
||||||
<w>responsetype</w>
|
<w>responsetype</w>
|
||||||
<w>responsetypes</w>
|
<w>responsetypes</w>
|
||||||
|
<w>responsetypevar</w>
|
||||||
<w>resultstr</w>
|
<w>resultstr</w>
|
||||||
<w>retrysecs</w>
|
<w>retrysecs</w>
|
||||||
<w>returncode</w>
|
<w>returncode</w>
|
||||||
@ -1921,6 +1930,7 @@
|
|||||||
<w>rtnetlink</w>
|
<w>rtnetlink</w>
|
||||||
<w>rtxt</w>
|
<w>rtxt</w>
|
||||||
<w>rtypes</w>
|
<w>rtypes</w>
|
||||||
|
<w>rtypevar</w>
|
||||||
<w>runmypy</w>
|
<w>runmypy</w>
|
||||||
<w>runonly</w>
|
<w>runonly</w>
|
||||||
<w>runpy</w>
|
<w>runpy</w>
|
||||||
@ -2053,6 +2063,7 @@
|
|||||||
<w>smag</w>
|
<w>smag</w>
|
||||||
<w>smallscale</w>
|
<w>smallscale</w>
|
||||||
<w>smlh</w>
|
<w>smlh</w>
|
||||||
|
<w>smod</w>
|
||||||
<w>smoothstep</w>
|
<w>smoothstep</w>
|
||||||
<w>smoothy</w>
|
<w>smoothy</w>
|
||||||
<w>smtpd</w>
|
<w>smtpd</w>
|
||||||
@ -2333,6 +2344,7 @@
|
|||||||
<w>touchpad</w>
|
<w>touchpad</w>
|
||||||
<w>tournamententry</w>
|
<w>tournamententry</w>
|
||||||
<w>tournamentscores</w>
|
<w>tournamentscores</w>
|
||||||
|
<w>tpimports</w>
|
||||||
<w>tplayer</w>
|
<w>tplayer</w>
|
||||||
<w>tpos</w>
|
<w>tpos</w>
|
||||||
<w>tproxy</w>
|
<w>tproxy</w>
|
||||||
|
|||||||
12
ballisticacore-cmake/.idea/dictionaries/ericf.xml
generated
12
ballisticacore-cmake/.idea/dictionaries/ericf.xml
generated
@ -179,6 +179,7 @@
|
|||||||
<w>chunksize</w>
|
<w>chunksize</w>
|
||||||
<w>cjief</w>
|
<w>cjief</w>
|
||||||
<w>classdict</w>
|
<w>classdict</w>
|
||||||
|
<w>classline</w>
|
||||||
<w>cleanupcheck</w>
|
<w>cleanupcheck</w>
|
||||||
<w>clientid</w>
|
<w>clientid</w>
|
||||||
<w>clientinfo</w>
|
<w>clientinfo</w>
|
||||||
@ -493,6 +494,7 @@
|
|||||||
<w>illum</w>
|
<w>illum</w>
|
||||||
<w>ilock</w>
|
<w>ilock</w>
|
||||||
<w>imagewidget</w>
|
<w>imagewidget</w>
|
||||||
|
<w>importlines</w>
|
||||||
<w>incentivized</w>
|
<w>incentivized</w>
|
||||||
<w>inet</w>
|
<w>inet</w>
|
||||||
<w>infotxt</w>
|
<w>infotxt</w>
|
||||||
@ -533,6 +535,7 @@
|
|||||||
<w>jaxis</w>
|
<w>jaxis</w>
|
||||||
<w>jcjwf</w>
|
<w>jcjwf</w>
|
||||||
<w>jmessage</w>
|
<w>jmessage</w>
|
||||||
|
<w>jnames</w>
|
||||||
<w>keepalives</w>
|
<w>keepalives</w>
|
||||||
<w>keyanntype</w>
|
<w>keyanntype</w>
|
||||||
<w>keycode</w>
|
<w>keycode</w>
|
||||||
@ -643,6 +646,9 @@
|
|||||||
<w>msgdict</w>
|
<w>msgdict</w>
|
||||||
<w>msgfull</w>
|
<w>msgfull</w>
|
||||||
<w>msgtype</w>
|
<w>msgtype</w>
|
||||||
|
<w>msgtypes</w>
|
||||||
|
<w>msgtypevar</w>
|
||||||
|
<w>mtypenames</w>
|
||||||
<w>mult</w>
|
<w>mult</w>
|
||||||
<w>multing</w>
|
<w>multing</w>
|
||||||
<w>multipass</w>
|
<w>multipass</w>
|
||||||
@ -746,6 +752,7 @@
|
|||||||
<w>osis</w>
|
<w>osis</w>
|
||||||
<w>osssssssssss</w>
|
<w>osssssssssss</w>
|
||||||
<w>ostype</w>
|
<w>ostype</w>
|
||||||
|
<w>ourcode</w>
|
||||||
<w>ourname</w>
|
<w>ourname</w>
|
||||||
<w>ourself</w>
|
<w>ourself</w>
|
||||||
<w>ourstanding</w>
|
<w>ourstanding</w>
|
||||||
@ -781,6 +788,7 @@
|
|||||||
<w>postinit</w>
|
<w>postinit</w>
|
||||||
<w>postrun</w>
|
<w>postrun</w>
|
||||||
<w>powerup</w>
|
<w>powerup</w>
|
||||||
|
<w>ppre</w>
|
||||||
<w>pptabcom</w>
|
<w>pptabcom</w>
|
||||||
<w>precalc</w>
|
<w>precalc</w>
|
||||||
<w>predeclare</w>
|
<w>predeclare</w>
|
||||||
@ -872,6 +880,7 @@
|
|||||||
<w>resetbtn</w>
|
<w>resetbtn</w>
|
||||||
<w>resetinput</w>
|
<w>resetinput</w>
|
||||||
<w>responsetypes</w>
|
<w>responsetypes</w>
|
||||||
|
<w>responsetypevar</w>
|
||||||
<w>resync</w>
|
<w>resync</w>
|
||||||
<w>retrysecs</w>
|
<w>retrysecs</w>
|
||||||
<w>retval</w>
|
<w>retval</w>
|
||||||
@ -889,6 +898,7 @@
|
|||||||
<w>rscode</w>
|
<w>rscode</w>
|
||||||
<w>rsgc</w>
|
<w>rsgc</w>
|
||||||
<w>rtypes</w>
|
<w>rtypes</w>
|
||||||
|
<w>rtypevar</w>
|
||||||
<w>runnables</w>
|
<w>runnables</w>
|
||||||
<w>rvec</w>
|
<w>rvec</w>
|
||||||
<w>rvel</w>
|
<w>rvel</w>
|
||||||
@ -945,6 +955,7 @@
|
|||||||
<w>simpletype</w>
|
<w>simpletype</w>
|
||||||
<w>sisssssssss</w>
|
<w>sisssssssss</w>
|
||||||
<w>sixteenbits</w>
|
<w>sixteenbits</w>
|
||||||
|
<w>smod</w>
|
||||||
<w>smoothering</w>
|
<w>smoothering</w>
|
||||||
<w>smoothstep</w>
|
<w>smoothstep</w>
|
||||||
<w>smoothy</w>
|
<w>smoothy</w>
|
||||||
@ -1053,6 +1064,7 @@
|
|||||||
<w>touchpad</w>
|
<w>touchpad</w>
|
||||||
<w>toucs</w>
|
<w>toucs</w>
|
||||||
<w>toutf</w>
|
<w>toutf</w>
|
||||||
|
<w>tpimports</w>
|
||||||
<w>tracebacks</w>
|
<w>tracebacks</w>
|
||||||
<w>tracestr</w>
|
<w>tracestr</w>
|
||||||
<w>trackpad</w>
|
<w>trackpad</w>
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
<!-- THIS FILE IS AUTO GENERATED; DO NOT EDIT BY HAND -->
|
<!-- THIS FILE IS AUTO GENERATED; DO NOT EDIT BY HAND -->
|
||||||
<h4><em>last updated on 2021-09-07 for Ballistica version 1.6.5 build 20391</em></h4>
|
<h4><em>last updated on 2021-09-08 for Ballistica version 1.6.5 build 20391</em></h4>
|
||||||
<p>This page documents the Python classes and functions in the 'ba' module,
|
<p>This page documents the Python classes and functions in the 'ba' module,
|
||||||
which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please <a href="mailto:support@froemling.net">let me know</a>. Happy modding!</p>
|
which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please <a href="mailto:support@froemling.net">let me know</a>. Happy modding!</p>
|
||||||
<hr>
|
<hr>
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
from typing import TYPE_CHECKING, overload
|
from typing import TYPE_CHECKING, overload
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@ -11,7 +12,7 @@ import pytest
|
|||||||
|
|
||||||
from efro.error import CleanError, RemoteError
|
from efro.error import CleanError, RemoteError
|
||||||
from efro.dataclassio import ioprepped
|
from efro.dataclassio import ioprepped
|
||||||
from efro.message import (Message, MessageProtocol, MessageSender,
|
from efro.message import (Message, Response, MessageProtocol, MessageSender,
|
||||||
MessageReceiver)
|
MessageReceiver)
|
||||||
from efrotools.statictest import static_type_equals
|
from efrotools.statictest import static_type_equals
|
||||||
|
|
||||||
@ -21,52 +22,52 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
@ioprepped
|
@ioprepped
|
||||||
@dataclass
|
@dataclass
|
||||||
class _TestMessage1(Message):
|
class _TMessage1(Message):
|
||||||
"""Just testing."""
|
"""Just testing."""
|
||||||
ival: int
|
ival: int
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_response_types(cls) -> List[Type[Message]]:
|
def get_response_types(cls) -> List[Type[Response]]:
|
||||||
return [_TestMessageR1]
|
return [_TResponse1]
|
||||||
|
|
||||||
|
|
||||||
@ioprepped
|
@ioprepped
|
||||||
@dataclass
|
@dataclass
|
||||||
class _TestMessage2(Message):
|
class _TMessage2(Message):
|
||||||
"""Just testing."""
|
"""Just testing."""
|
||||||
sval: str
|
sval: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_response_types(cls) -> List[Type[Message]]:
|
def get_response_types(cls) -> List[Type[Response]]:
|
||||||
return [_TestMessageR1, _TestMessageR2]
|
return [_TResponse1, _TResponse2]
|
||||||
|
|
||||||
|
|
||||||
@ioprepped
|
@ioprepped
|
||||||
@dataclass
|
@dataclass
|
||||||
class _TestMessageR1(Message):
|
class _TResponse1(Response):
|
||||||
"""Just testing."""
|
"""Just testing."""
|
||||||
bval: bool
|
bval: bool
|
||||||
|
|
||||||
|
|
||||||
@ioprepped
|
@ioprepped
|
||||||
@dataclass
|
@dataclass
|
||||||
class _TestMessageR2(Message):
|
class _TResponse2(Response):
|
||||||
"""Just testing."""
|
"""Just testing."""
|
||||||
fval: float
|
fval: float
|
||||||
|
|
||||||
|
|
||||||
@ioprepped
|
@ioprepped
|
||||||
@dataclass
|
@dataclass
|
||||||
class _TestMessageR3(Message):
|
class _TResponse3(Message):
|
||||||
"""Just testing."""
|
"""Just testing."""
|
||||||
fval: float
|
fval: float
|
||||||
|
|
||||||
|
|
||||||
class _TestMessageSender(MessageSender):
|
# SEND_CODE_TEST_BEGIN
|
||||||
"""Testing type overrides for message sending.
|
|
||||||
|
|
||||||
Normally this would be auto-generated based on the protocol.
|
|
||||||
"""
|
class _TestMessageSender(MessageSender):
|
||||||
|
"""Protocol-specific sender."""
|
||||||
|
|
||||||
def __get__(self,
|
def __get__(self,
|
||||||
obj: Any,
|
obj: Any,
|
||||||
@ -75,10 +76,7 @@ class _TestMessageSender(MessageSender):
|
|||||||
|
|
||||||
|
|
||||||
class _BoundTestMessageSender:
|
class _BoundTestMessageSender:
|
||||||
"""Testing type overrides for message sending.
|
"""Protocol-specific bound sender."""
|
||||||
|
|
||||||
Normally this would be auto-generated based on the protocol.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, obj: Any, sender: _TestMessageSender) -> None:
|
def __init__(self, obj: Any, sender: _TestMessageSender) -> None:
|
||||||
assert obj is not None
|
assert obj is not None
|
||||||
@ -86,56 +84,60 @@ class _BoundTestMessageSender:
|
|||||||
self._sender = sender
|
self._sender = sender
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def send(self, message: _TestMessage1) -> _TestMessageR1:
|
def send(self, message: _TMessage1) -> _TResponse1:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def send(self,
|
def send(self, message: _TMessage2) -> Union[_TResponse1, _TResponse2]:
|
||||||
message: _TestMessage2) -> Union[_TestMessageR1, _TestMessageR2]:
|
|
||||||
...
|
...
|
||||||
|
|
||||||
def send(self, message: Message) -> Message:
|
def send(self, message: Message) -> Response:
|
||||||
"""Send a particular message type."""
|
"""Send a message."""
|
||||||
return self._sender.send(self._obj, message)
|
return self._sender.send(self._obj, message)
|
||||||
|
|
||||||
|
|
||||||
|
# SEND_CODE_TEST_END
|
||||||
|
# RCV_CODE_TEST_BEGIN
|
||||||
|
|
||||||
|
|
||||||
class _TestMessageReceiver(MessageReceiver):
|
class _TestMessageReceiver(MessageReceiver):
|
||||||
"""Testing type overrides for message receiving.
|
"""Protocol-specific receiver."""
|
||||||
|
|
||||||
Normally this would be auto-generated based on the protocol.
|
def __get__(
|
||||||
"""
|
self,
|
||||||
|
obj: Any,
|
||||||
def __get__(self,
|
type_in: Any = None,
|
||||||
obj: Any,
|
) -> _BoundTestMessageReceiver:
|
||||||
type_in: Any = None) -> _BoundTestMessageReceiver:
|
|
||||||
return _BoundTestMessageReceiver(obj, self)
|
return _BoundTestMessageReceiver(obj, self)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def handler(
|
def handler(
|
||||||
self, call: Callable[[Any, _TestMessage1], _TestMessageR1]
|
self,
|
||||||
) -> Callable[[Any, _TestMessage1], _TestMessageR1]:
|
call: Callable[[Any, _TMessage1], _TResponse1],
|
||||||
|
) -> Callable[[Any, _TMessage1], _TResponse1]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def handler(
|
def handler(
|
||||||
self, call: Callable[[Any, _TestMessage2], Union[_TestMessageR1,
|
self,
|
||||||
_TestMessageR2]]
|
call: Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]],
|
||||||
) -> Callable[[Any, _TestMessage2], Union[_TestMessageR1, _TestMessageR2]]:
|
) -> Callable[[Any, _TMessage2], Union[_TResponse1, _TResponse2]]:
|
||||||
...
|
...
|
||||||
|
|
||||||
def handler(self, call: Callable) -> Callable:
|
def handler(self, call: Callable) -> Callable:
|
||||||
"""Decorator to register a handler for a particular message type."""
|
"""Decorator to register message handlers."""
|
||||||
self.register_handler(call)
|
self.register_handler(call)
|
||||||
return call
|
return call
|
||||||
|
|
||||||
|
|
||||||
class _BoundTestMessageReceiver:
|
class _BoundTestMessageReceiver:
|
||||||
"""Testing type overrides for message receiving.
|
"""Protocol-specific bound receiver."""
|
||||||
|
|
||||||
Normally this would be auto-generated based on the protocol.
|
def __init__(
|
||||||
"""
|
self,
|
||||||
|
obj: Any,
|
||||||
def __init__(self, obj: Any, receiver: _TestMessageReceiver) -> None:
|
receiver: _TestMessageReceiver,
|
||||||
|
) -> None:
|
||||||
assert obj is not None
|
assert obj is not None
|
||||||
self._obj = obj
|
self._obj = obj
|
||||||
self._receiver = receiver
|
self._receiver = receiver
|
||||||
@ -145,12 +147,14 @@ class _BoundTestMessageReceiver:
|
|||||||
return self._receiver.handle_raw_message(self._obj, message)
|
return self._receiver.handle_raw_message(self._obj, message)
|
||||||
|
|
||||||
|
|
||||||
|
# RCV_CODE_TEST_END
|
||||||
|
|
||||||
TEST_PROTOCOL = MessageProtocol(
|
TEST_PROTOCOL = MessageProtocol(
|
||||||
message_types={
|
message_types={
|
||||||
1: _TestMessage1,
|
1: _TMessage1,
|
||||||
2: _TestMessage2,
|
2: _TMessage2,
|
||||||
3: _TestMessageR1,
|
3: _TResponse1,
|
||||||
4: _TestMessageR2,
|
4: _TResponse2,
|
||||||
},
|
},
|
||||||
trusted_client=True,
|
trusted_client=True,
|
||||||
log_remote_exceptions=False,
|
log_remote_exceptions=False,
|
||||||
@ -160,20 +164,61 @@ TEST_PROTOCOL = MessageProtocol(
|
|||||||
def test_protocol_creation() -> None:
|
def test_protocol_creation() -> None:
|
||||||
"""Test protocol creation."""
|
"""Test protocol creation."""
|
||||||
|
|
||||||
# This should fail because _TestMessage1 can return _TestMessageR1 which
|
# This should fail because _TMessage1 can return _TResponse1 which
|
||||||
# is not given an id here.
|
# is not given an id here.
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
_protocol = MessageProtocol(message_types={1: _TestMessage1})
|
_protocol = MessageProtocol(message_types={1: _TMessage1})
|
||||||
|
|
||||||
# Now it should work.
|
# Now it should work.
|
||||||
_protocol = MessageProtocol(message_types={
|
_protocol = MessageProtocol(message_types={1: _TMessage1, 2: _TResponse1})
|
||||||
1: _TestMessage1,
|
|
||||||
2: _TestMessageR1
|
|
||||||
})
|
def test_sender_module_creation() -> None:
|
||||||
|
"""Test generation of protocol-specific sender modules for typing/etc."""
|
||||||
|
smod = TEST_PROTOCOL.create_sender_module('Test', private=True)
|
||||||
|
|
||||||
|
# Clip everything up to our first class declaration.
|
||||||
|
lines = smod.splitlines()
|
||||||
|
classline = lines.index('class _TestMessageSender(MessageSender):')
|
||||||
|
clipped = '\n'.join(lines[classline:])
|
||||||
|
|
||||||
|
# This snippet should match what we've got embedded above;
|
||||||
|
# If not then we need to update our test code.
|
||||||
|
with open(__file__, encoding='utf-8') as infile:
|
||||||
|
ourcode = infile.read()
|
||||||
|
|
||||||
|
emb = f'# SEND_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# SEND_CODE_TEST_END\n'
|
||||||
|
if emb not in ourcode:
|
||||||
|
print(f'EXPECTED EMBEDDED CODE:\n{emb}')
|
||||||
|
raise RuntimeError('Generated sender module does not match embedded;'
|
||||||
|
' test code needs to be updated.'
|
||||||
|
' See test stdout for new code.')
|
||||||
|
|
||||||
|
|
||||||
|
def test_receiver_module_creation() -> None:
|
||||||
|
"""Test generation of protocol-specific sender modules for typing/etc."""
|
||||||
|
smod = TEST_PROTOCOL.create_receiver_module('Test', private=True)
|
||||||
|
|
||||||
|
# Clip everything up to our first class declaration.
|
||||||
|
lines = smod.splitlines()
|
||||||
|
classline = lines.index('class _TestMessageReceiver(MessageReceiver):')
|
||||||
|
clipped = '\n'.join(lines[classline:])
|
||||||
|
|
||||||
|
# This snippet should match what we've got embedded above;
|
||||||
|
# If not then we need to update our test code.
|
||||||
|
with open(__file__, encoding='utf-8') as infile:
|
||||||
|
ourcode = infile.read()
|
||||||
|
|
||||||
|
emb = f'# RCV_CODE_TEST_BEGIN\n\n\n{clipped}\n\n\n# RCV_CODE_TEST_END\n'
|
||||||
|
if emb not in ourcode:
|
||||||
|
print(f'EXPECTED EMBEDDED CODE:\n{emb}')
|
||||||
|
raise RuntimeError('Generated sender module does not match embedded;'
|
||||||
|
' test code needs to be updated.'
|
||||||
|
' See test stdout for new code.')
|
||||||
|
|
||||||
|
|
||||||
def test_receiver_creation() -> None:
|
def test_receiver_creation() -> None:
|
||||||
"""Test receiver creation"""
|
"""Test receiver creation."""
|
||||||
|
|
||||||
# This should fail due to the registered handler only specifying
|
# This should fail due to the registered handler only specifying
|
||||||
# one response message type while the message type itself
|
# one response message type while the message type itself
|
||||||
@ -186,12 +231,10 @@ def test_receiver_creation() -> None:
|
|||||||
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
||||||
|
|
||||||
@receiver.handler
|
@receiver.handler
|
||||||
def handle_test_message_2(self,
|
def handle_test_message_2(self, msg: _TMessage2) -> _TResponse2:
|
||||||
msg: _TestMessage2) -> _TestMessageR2:
|
|
||||||
"""Test."""
|
"""Test."""
|
||||||
del msg # Unused
|
del msg # Unused
|
||||||
print('Hello from test message 1 handler!')
|
return _TResponse2(fval=1.2)
|
||||||
return _TestMessageR2(fval=1.2)
|
|
||||||
|
|
||||||
# Should fail because not all message types in the protocol are handled.
|
# Should fail because not all message types in the protocol are handled.
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
@ -200,7 +243,7 @@ def test_receiver_creation() -> None:
|
|||||||
"""Test class incorporating receive functionality."""
|
"""Test class incorporating receive functionality."""
|
||||||
|
|
||||||
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
||||||
receiver.validate_handler_completeness()
|
receiver.validate()
|
||||||
|
|
||||||
|
|
||||||
def test_message_sending() -> None:
|
def test_message_sending() -> None:
|
||||||
@ -226,42 +269,40 @@ def test_message_sending() -> None:
|
|||||||
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
receiver = _TestMessageReceiver(TEST_PROTOCOL)
|
||||||
|
|
||||||
@receiver.handler
|
@receiver.handler
|
||||||
def handle_test_message_1(self, msg: _TestMessage1) -> _TestMessageR1:
|
def handle_test_message_1(self, msg: _TMessage1) -> _TResponse1:
|
||||||
"""Test."""
|
"""Test."""
|
||||||
print('Hello from test message 1 handler!')
|
|
||||||
if msg.ival == 1:
|
if msg.ival == 1:
|
||||||
raise CleanError('Testing Clean Error')
|
raise CleanError('Testing Clean Error')
|
||||||
if msg.ival == 2:
|
if msg.ival == 2:
|
||||||
raise RuntimeError('Testing Runtime Error')
|
raise RuntimeError('Testing Runtime Error')
|
||||||
return _TestMessageR1(bval=True)
|
return _TResponse1(bval=True)
|
||||||
|
|
||||||
@receiver.handler
|
@receiver.handler
|
||||||
def handle_test_message_2(
|
def handle_test_message_2(
|
||||||
self,
|
self, msg: _TMessage2) -> Union[_TResponse1, _TResponse2]:
|
||||||
msg: _TestMessage2) -> Union[_TestMessageR1, _TestMessageR2]:
|
|
||||||
"""Test."""
|
"""Test."""
|
||||||
del msg # Unused
|
del msg # Unused
|
||||||
print('Hello from test message 2 handler!')
|
return _TResponse2(fval=1.2)
|
||||||
return _TestMessageR2(fval=1.2)
|
|
||||||
|
|
||||||
receiver.validate_handler_completeness()
|
receiver.validate()
|
||||||
|
|
||||||
obj_r = TestClassR()
|
obj_r = TestClassR()
|
||||||
obj_s = TestClassS(target=obj_r)
|
obj_s = TestClassS(target=obj_r)
|
||||||
|
|
||||||
response = obj_s.msg.send(_TestMessage1(ival=0))
|
if os.environ.get('EFRO_TEST_MESSAGE_FAST') != '1':
|
||||||
response2 = obj_s.msg.send(_TestMessage2(sval='rah'))
|
response = obj_s.msg.send(_TMessage1(ival=0))
|
||||||
assert static_type_equals(response, _TestMessageR1)
|
response2 = obj_s.msg.send(_TMessage2(sval='rah'))
|
||||||
assert isinstance(response, _TestMessageR1)
|
assert static_type_equals(response, _TResponse1)
|
||||||
assert isinstance(response2, (_TestMessageR1, _TestMessageR2))
|
assert isinstance(response, _TResponse1)
|
||||||
|
assert isinstance(response2, (_TResponse1, _TResponse2))
|
||||||
|
|
||||||
# Remote CleanErrors should come across locally as the same.
|
# Remote CleanErrors should come across locally as the same.
|
||||||
try:
|
try:
|
||||||
_response3 = obj_s.msg.send(_TestMessage1(ival=1))
|
_response3 = obj_s.msg.send(_TMessage1(ival=1))
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
assert isinstance(exc, CleanError)
|
assert isinstance(exc, CleanError)
|
||||||
assert str(exc) == 'Testing Clean Error'
|
assert str(exc) == 'Testing Clean Error'
|
||||||
|
|
||||||
# Other remote errors should come across as RemoteError.
|
# Other remote errors should come across as RemoteError.
|
||||||
with pytest.raises(RemoteError):
|
with pytest.raises(RemoteError):
|
||||||
_response4 = obj_s.msg.send(_TestMessage1(ival=2))
|
_response4 = obj_s.msg.send(_TMessage1(ival=2))
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs,
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set,
|
from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set,
|
||||||
Sequence)
|
Sequence, Union)
|
||||||
from efro.error import CommunicationError
|
from efro.error import CommunicationError
|
||||||
|
|
||||||
TM = TypeVar('TM', bound='MessageSender')
|
TM = TypeVar('TM', bound='MessageSender')
|
||||||
@ -35,21 +35,24 @@ class RemoteErrorType(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class Message:
|
class Message:
|
||||||
"""Base class for messages and their responses."""
|
"""Base class for messages."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_response_types(cls) -> List[Type[Message]]:
|
def get_response_types(cls) -> List[Type[Response]]:
|
||||||
"""Return all message types this Message can result in when sent.
|
"""Return all message types this Message can result in when sent.
|
||||||
Messages intended only for response types can leave this empty.
|
|
||||||
Note: RemoteErrorMessage is handled transparently and does not
|
Note: RemoteErrorMessage is handled transparently and does not
|
||||||
need to be specified here.
|
need to be specified here.
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class Response:
|
||||||
|
"""Base class for responses to messages."""
|
||||||
|
|
||||||
|
|
||||||
@ioprepped
|
@ioprepped
|
||||||
@dataclass
|
@dataclass
|
||||||
class RemoteErrorMessage(Message):
|
class RemoteErrorMessage(Response):
|
||||||
"""Message saying some error has occurred on the other end."""
|
"""Message saying some error has occurred on the other end."""
|
||||||
error_message: Annotated[str, IOAttrs('m')]
|
error_message: Annotated[str, IOAttrs('m')]
|
||||||
error_type: Annotated[RemoteErrorType, IOAttrs('t')]
|
error_type: Annotated[RemoteErrorType, IOAttrs('t')]
|
||||||
@ -64,14 +67,13 @@ class MessageProtocol:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
message_types: Dict[int, Type[Message]],
|
message_types: Dict[int, Union[Type[Message],
|
||||||
|
Type[Response]]],
|
||||||
type_key: Optional[str] = None,
|
type_key: Optional[str] = None,
|
||||||
preserve_clean_errors: bool = True,
|
preserve_clean_errors: bool = True,
|
||||||
log_remote_exceptions: bool = True,
|
log_remote_exceptions: bool = True,
|
||||||
trusted_client: bool = False) -> None:
|
trusted_client: bool = False) -> None:
|
||||||
"""Create a protocol with a given configuration.
|
"""Create a protocol with a given configuration.
|
||||||
Each entry for message_types should contain an ID, a message type,
|
|
||||||
and all possible response types.
|
|
||||||
|
|
||||||
If 'type_key' is provided, the message type ID is stored as the
|
If 'type_key' is provided, the message type ID is stored as the
|
||||||
provided key in the message dict; otherwise it will be stored as
|
provided key in the message dict; otherwise it will be stored as
|
||||||
@ -86,8 +88,10 @@ class MessageProtocol:
|
|||||||
be included in the RemoteError. This should only be enabled in cases
|
be included in the RemoteError. This should only be enabled in cases
|
||||||
where the client is trusted.
|
where the client is trusted.
|
||||||
"""
|
"""
|
||||||
self.message_types_by_id: Dict[int, Type[Message]] = {}
|
self.message_types_by_id: Dict[int, Union[Type[Message],
|
||||||
self.message_ids_by_type: Dict[Type[Message], int] = {}
|
Type[Response]]] = {}
|
||||||
|
self.message_ids_by_type: Dict[Union[Type[Message], Type[Response]],
|
||||||
|
int] = {}
|
||||||
for m_id, m_type in message_types.items():
|
for m_id, m_type in message_types.items():
|
||||||
|
|
||||||
# Make sure only valid message types were passed and each
|
# Make sure only valid message types were passed and each
|
||||||
@ -95,34 +99,47 @@ class MessageProtocol:
|
|||||||
assert isinstance(m_id, int)
|
assert isinstance(m_id, int)
|
||||||
assert m_id >= 0
|
assert m_id >= 0
|
||||||
assert (is_ioprepped_dataclass(m_type)
|
assert (is_ioprepped_dataclass(m_type)
|
||||||
and issubclass(m_type, Message))
|
and issubclass(m_type, (Message, Response)))
|
||||||
assert self.message_types_by_id.get(m_id) is None
|
assert self.message_types_by_id.get(m_id) is None
|
||||||
|
|
||||||
self.message_types_by_id[m_id] = m_type
|
self.message_types_by_id[m_id] = m_type
|
||||||
self.message_ids_by_type[m_type] = m_id
|
self.message_ids_by_type[m_type] = m_id
|
||||||
|
|
||||||
# Make sure all return types are valid and have been assigned
|
# Some extra-thorough validation in debug mode.
|
||||||
# an ID as well.
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
all_response_types: Set[Type[Message]] = set()
|
# Make sure all return types are valid and have been assigned
|
||||||
|
# an ID as well.
|
||||||
|
all_response_types: Set[Type[Response]] = set()
|
||||||
for m_id, m_type in message_types.items():
|
for m_id, m_type in message_types.items():
|
||||||
m_rtypes = m_type.get_response_types()
|
if issubclass(m_type, Message):
|
||||||
assert isinstance(m_rtypes, list)
|
m_rtypes = m_type.get_response_types()
|
||||||
assert len(set(m_rtypes)) == len(m_rtypes) # check for dups
|
assert isinstance(m_rtypes, list)
|
||||||
all_response_types.update(m_rtypes)
|
assert m_rtypes # make sure not empty
|
||||||
|
assert len(set(m_rtypes)) == len(m_rtypes) # check dups
|
||||||
|
all_response_types.update(m_rtypes)
|
||||||
for cls in all_response_types:
|
for cls in all_response_types:
|
||||||
assert is_ioprepped_dataclass(cls) and issubclass(cls, Message)
|
assert is_ioprepped_dataclass(cls) and issubclass(
|
||||||
|
cls, (Message, Response))
|
||||||
if cls not in self.message_ids_by_type:
|
if cls not in self.message_ids_by_type:
|
||||||
raise ValueError(f'Possible response type {cls}'
|
raise ValueError(f'Possible response type {cls}'
|
||||||
f' was not included in message_types.')
|
f' was not included in message_types.')
|
||||||
|
|
||||||
|
# Make sure all registered types have unique base names.
|
||||||
|
# We can take advantage of this to generate cleaner looking
|
||||||
|
# protocol modules. Can revisit if this is ever a problem.
|
||||||
|
mtypenames = set(tp.__name__ for tp in self.message_ids_by_type)
|
||||||
|
if len(mtypenames) != len(message_types):
|
||||||
|
raise ValueError(
|
||||||
|
'message_types contains duplicate __name__s;'
|
||||||
|
' all types are required to have unique names.')
|
||||||
|
|
||||||
self._type_key = type_key
|
self._type_key = type_key
|
||||||
self.preserve_clean_errors = preserve_clean_errors
|
self.preserve_clean_errors = preserve_clean_errors
|
||||||
self.log_remote_exceptions = log_remote_exceptions
|
self.log_remote_exceptions = log_remote_exceptions
|
||||||
self.trusted_client = trusted_client
|
self.trusted_client = trusted_client
|
||||||
|
|
||||||
def message_encode(self,
|
def message_encode(self,
|
||||||
message: Message,
|
message: Union[Message, Response],
|
||||||
is_error: bool = False) -> bytes:
|
is_error: bool = False) -> bytes:
|
||||||
"""Encode a message to bytes for transport."""
|
"""Encode a message to bytes for transport."""
|
||||||
|
|
||||||
@ -132,8 +149,9 @@ class MessageProtocol:
|
|||||||
else:
|
else:
|
||||||
m_id = self.message_ids_by_type.get(type(message))
|
m_id = self.message_ids_by_type.get(type(message))
|
||||||
if m_id is None:
|
if m_id is None:
|
||||||
raise TypeError(f'Message type is not registered in Protocol:'
|
raise TypeError(
|
||||||
f' {type(message)}')
|
f'Message/Response type is not registered in Protocol:'
|
||||||
|
f' {type(message)}')
|
||||||
msgdict = dataclass_to_dict(message)
|
msgdict = dataclass_to_dict(message)
|
||||||
|
|
||||||
# Encode type as part of the message dict if desired
|
# Encode type as part of the message dict if desired
|
||||||
@ -148,7 +166,7 @@ class MessageProtocol:
|
|||||||
out = {'m': msgdict, 't': m_id}
|
out = {'m': msgdict, 't': m_id}
|
||||||
return json.dumps(out, separators=(',', ':')).encode()
|
return json.dumps(out, separators=(',', ':')).encode()
|
||||||
|
|
||||||
def message_decode(self, data: bytes) -> Message:
|
def message_decode(self, data: bytes) -> Union[Message, Response]:
|
||||||
"""Decode a message from bytes.
|
"""Decode a message from bytes.
|
||||||
|
|
||||||
If the message represents a remote error, an Exception will
|
If the message represents a remote error, an Exception will
|
||||||
@ -178,25 +196,200 @@ class MessageProtocol:
|
|||||||
# Decode this particular type and make sure its valid.
|
# Decode this particular type and make sure its valid.
|
||||||
msgtype = self.message_types_by_id.get(m_id)
|
msgtype = self.message_types_by_id.get(m_id)
|
||||||
if msgtype is None:
|
if msgtype is None:
|
||||||
raise TypeError(f'Got unregistered message type id of {m_id}.')
|
raise TypeError(
|
||||||
|
f'Got unregistered message/response type id of {m_id}.')
|
||||||
|
|
||||||
return dataclass_from_dict(msgtype, msgdict)
|
out = dataclass_from_dict(msgtype, msgdict)
|
||||||
|
assert isinstance(out, (Message, Response))
|
||||||
|
return out
|
||||||
|
|
||||||
def create_sender_module(self, classname: str) -> str:
|
def _get_module_header(self, part: str) -> str:
|
||||||
|
"""Return common parts of generated modules."""
|
||||||
|
imports: Dict[str, List[str]] = {}
|
||||||
|
for msgtype in self.message_ids_by_type:
|
||||||
|
imports.setdefault(msgtype.__module__, []).append(msgtype.__name__)
|
||||||
|
importlines = ''
|
||||||
|
for module, names in sorted(imports.items()):
|
||||||
|
jnames = ', '.join(names)
|
||||||
|
line = f'from {module} import {jnames}'
|
||||||
|
if len(line) > 79:
|
||||||
|
# Recreate in a wrapping-friendly form.
|
||||||
|
line = f'from {module} import ({jnames})'
|
||||||
|
importlines += f'{line}\n'
|
||||||
|
|
||||||
|
if part == 'sender':
|
||||||
|
importlines = (
|
||||||
|
f'from efro.message import MessageSender\n{importlines}')
|
||||||
|
tpimports = 'from efro.message import Message, Response'
|
||||||
|
else:
|
||||||
|
importlines = (
|
||||||
|
f'from efro.message import MessageReceiver\n{importlines}')
|
||||||
|
tpimports = 'from efro.message import Message, Response'
|
||||||
|
|
||||||
|
out = ('# Released under the MIT License. See LICENSE for details.\n'
|
||||||
|
f'#\n'
|
||||||
|
f'"""Auto-generated {part} module."""\n'
|
||||||
|
f'\n'
|
||||||
|
f'from __future__ import annotations\n'
|
||||||
|
f'\n'
|
||||||
|
f'from typing import TYPE_CHECKING, overload\n'
|
||||||
|
f'\n'
|
||||||
|
f'{importlines}'
|
||||||
|
f'\n'
|
||||||
|
f'if TYPE_CHECKING:\n'
|
||||||
|
f' from typing import Union\n'
|
||||||
|
f' {tpimports}\n'
|
||||||
|
f'\n'
|
||||||
|
f'\n')
|
||||||
|
return out
|
||||||
|
|
||||||
|
def create_sender_module(self,
|
||||||
|
classname: str,
|
||||||
|
private: bool = False) -> str:
|
||||||
""""Create a Python module defining a MessageSender subclass.
|
""""Create a Python module defining a MessageSender subclass.
|
||||||
|
|
||||||
This class is primarily for type checking and will contain overrides
|
This class is primarily for type checking and will contain overrides
|
||||||
for the varieties of send calls for message/response types defined
|
for the varieties of send calls for message/response types defined
|
||||||
in the protocol.
|
in the protocol.
|
||||||
|
|
||||||
|
Note that line lengths are not clipped, so output may need to be
|
||||||
|
run through a formatter to prevent lint warnings about excessive
|
||||||
|
line lengths.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def create_receiver_module(self, classname: str) -> str:
|
ppre = '_' if private else ''
|
||||||
|
out = self._get_module_header('sender')
|
||||||
|
out += (f'class {ppre}{classname}MessageSender(MessageSender):\n'
|
||||||
|
f' """Protocol-specific sender."""\n'
|
||||||
|
f'\n'
|
||||||
|
f' def __get__(self,\n'
|
||||||
|
f' obj: Any,\n'
|
||||||
|
f' type_in: Any = None)'
|
||||||
|
f' -> {ppre}Bound{classname}MessageSender:\n'
|
||||||
|
f' return {ppre}Bound{classname}MessageSender'
|
||||||
|
f'(obj, self)\n'
|
||||||
|
f'\n'
|
||||||
|
f'\n'
|
||||||
|
f'class {ppre}Bound{classname}MessageSender:\n'
|
||||||
|
f' """Protocol-specific bound sender."""\n'
|
||||||
|
f'\n'
|
||||||
|
f' def __init__(self, obj: Any,'
|
||||||
|
f' sender: {ppre}{classname}MessageSender) -> None:\n'
|
||||||
|
f' assert obj is not None\n'
|
||||||
|
f' self._obj = obj\n'
|
||||||
|
f' self._sender = sender\n')
|
||||||
|
|
||||||
|
# Define handler() overloads for all registered message types.
|
||||||
|
msgtypes = [
|
||||||
|
t for t in self.message_ids_by_type if issubclass(t, Message)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Ew; @overload requires at least 2 different signatures so
|
||||||
|
# we need to simply write a single function if we have < 2.
|
||||||
|
if len(msgtypes) == 1:
|
||||||
|
raise RuntimeError('FIXME: currently require at least 2'
|
||||||
|
' message types.')
|
||||||
|
if len(msgtypes) > 1:
|
||||||
|
for msgtype in msgtypes:
|
||||||
|
msgtypevar = msgtype.__name__
|
||||||
|
rtypes = msgtype.get_response_types()
|
||||||
|
if len(rtypes) > 1:
|
||||||
|
tps = ', '.join(t.__name__ for t in rtypes)
|
||||||
|
responsetypevar = f'Union[{tps}]'
|
||||||
|
else:
|
||||||
|
responsetypevar = rtypes[0].__name__
|
||||||
|
out += (f'\n'
|
||||||
|
f' @overload\n'
|
||||||
|
f' def send(self, message: {msgtypevar})'
|
||||||
|
f' -> {responsetypevar}:\n'
|
||||||
|
f' ...\n')
|
||||||
|
out += ('\n'
|
||||||
|
' def send(self, message: Message) -> Response:\n'
|
||||||
|
' """Send a message."""\n'
|
||||||
|
' return self._sender.send(self._obj, message)\n')
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def create_receiver_module(self,
|
||||||
|
classname: str,
|
||||||
|
private: bool = False) -> str:
|
||||||
""""Create a Python module defining a MessageReceiver subclass.
|
""""Create a Python module defining a MessageReceiver subclass.
|
||||||
|
|
||||||
This class is primarily for type checking and will contain overrides
|
This class is primarily for type checking and will contain overrides
|
||||||
for the register method for message/response types defined in
|
for the register method for message/response types defined in
|
||||||
the protocol.
|
the protocol.
|
||||||
|
|
||||||
|
Note that line lengths are not clipped, so output may need to be
|
||||||
|
run through a formatter to prevent lint warnings about excessive
|
||||||
|
line lengths.
|
||||||
"""
|
"""
|
||||||
|
ppre = '_' if private else ''
|
||||||
|
out = self._get_module_header('receiver')
|
||||||
|
out += (f'class {ppre}{classname}MessageReceiver(MessageReceiver):\n'
|
||||||
|
f' """Protocol-specific receiver."""\n'
|
||||||
|
f'\n'
|
||||||
|
f' def __get__(\n'
|
||||||
|
f' self,\n'
|
||||||
|
f' obj: Any,\n'
|
||||||
|
f' type_in: Any = None,\n'
|
||||||
|
f' ) -> {ppre}Bound{classname}MessageReceiver:\n'
|
||||||
|
f' return {ppre}Bound{classname}MessageReceiver('
|
||||||
|
f'obj, self)\n')
|
||||||
|
|
||||||
|
# Define handler() overloads for all registered message types.
|
||||||
|
msgtypes = [
|
||||||
|
t for t in self.message_ids_by_type if issubclass(t, Message)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Ew; @overload requires at least 2 different signatures so
|
||||||
|
# we need to simply write a single function if we have < 2.
|
||||||
|
if len(msgtypes) == 1:
|
||||||
|
raise RuntimeError('FIXME: currently require at least 2'
|
||||||
|
' message types.')
|
||||||
|
if len(msgtypes) > 1:
|
||||||
|
for msgtype in msgtypes:
|
||||||
|
msgtypevar = msgtype.__name__
|
||||||
|
rtypes = msgtype.get_response_types()
|
||||||
|
if len(rtypes) > 1:
|
||||||
|
tps = ', '.join(t.__name__ for t in rtypes)
|
||||||
|
rtypevar = f'Union[{tps}]'
|
||||||
|
else:
|
||||||
|
rtypevar = rtypes[0].__name__
|
||||||
|
out += (
|
||||||
|
f'\n'
|
||||||
|
f' @overload\n'
|
||||||
|
f' def handler(\n'
|
||||||
|
f' self,\n'
|
||||||
|
f' call: Callable[[Any, {msgtypevar}], '
|
||||||
|
f'{rtypevar}],\n'
|
||||||
|
f' ) -> Callable[[Any, {msgtypevar}], {rtypevar}]:\n'
|
||||||
|
f' ...\n')
|
||||||
|
out += ('\n'
|
||||||
|
' def handler(self, call: Callable) -> Callable:\n'
|
||||||
|
' """Decorator to register message handlers."""\n'
|
||||||
|
' self.register_handler(call)\n'
|
||||||
|
' return call\n')
|
||||||
|
|
||||||
|
out += (f'\n'
|
||||||
|
f'\n'
|
||||||
|
f'class {ppre}Bound{classname}MessageReceiver:\n'
|
||||||
|
f' """Protocol-specific bound receiver."""\n'
|
||||||
|
f'\n'
|
||||||
|
f' def __init__(\n'
|
||||||
|
f' self,\n'
|
||||||
|
f' obj: Any,\n'
|
||||||
|
f' receiver: _TestMessageReceiver,\n'
|
||||||
|
f' ) -> None:\n'
|
||||||
|
f' assert obj is not None\n'
|
||||||
|
f' self._obj = obj\n'
|
||||||
|
f' self._receiver = receiver\n'
|
||||||
|
f'\n'
|
||||||
|
f' def handle_raw_message(self, message: bytes) -> bytes:\n'
|
||||||
|
f' """Handle a raw incoming message."""\n'
|
||||||
|
f' return self._receiver.handle_raw_message'
|
||||||
|
f'(self._obj, message)\n')
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class MessageSender:
|
class MessageSender:
|
||||||
@ -232,7 +425,7 @@ class MessageSender:
|
|||||||
self._send_raw_message_call = call
|
self._send_raw_message_call = call
|
||||||
return call
|
return call
|
||||||
|
|
||||||
def send(self, bound_obj: Any, message: Message) -> Message:
|
def send(self, bound_obj: Any, message: Message) -> Response:
|
||||||
"""Send a message and receive a response.
|
"""Send a message and receive a response.
|
||||||
|
|
||||||
Will encode the message for transport and call dispatch_raw_message()
|
Will encode the message for transport and call dispatch_raw_message()
|
||||||
@ -240,13 +433,10 @@ class MessageSender:
|
|||||||
if self._send_raw_message_call is None:
|
if self._send_raw_message_call is None:
|
||||||
raise RuntimeError('send() is unimplemented for this type.')
|
raise RuntimeError('send() is unimplemented for this type.')
|
||||||
|
|
||||||
# Only types with possible response types should ever be sent.
|
|
||||||
assert type(message).get_response_types()
|
|
||||||
|
|
||||||
msg_encoded = self._protocol.message_encode(message)
|
msg_encoded = self._protocol.message_encode(message)
|
||||||
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
|
response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
|
||||||
response = self._protocol.message_decode(response_encoded)
|
response = self._protocol.message_decode(response_encoded)
|
||||||
assert isinstance(response, Message)
|
assert isinstance(response, Response)
|
||||||
assert type(response) in type(message).get_response_types()
|
assert type(response) in type(message).get_response_types()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@ -278,10 +468,10 @@ class MessageReceiver:
|
|||||||
class MyClass:
|
class MyClass:
|
||||||
receiver = MyMessageReceiver()
|
receiver = MyMessageReceiver()
|
||||||
|
|
||||||
# MyMessageReceiver should provide overloads to register_handler()
|
# MyMessageReceiver fills out handler() overloads to ensure all
|
||||||
# to ensure all registered handlers have valid types/return-types.
|
# registered handlers have valid types/return-types.
|
||||||
@receiver.handler
|
@receiver.handler
|
||||||
def handle_some_message_type(self, message: SomeType) -> AnotherType:
|
def handle_some_message_type(self, message: SomeMsg) -> SomeResponse:
|
||||||
# Deal with this message type here.
|
# Deal with this message type here.
|
||||||
|
|
||||||
# This will trigger the registered handler being called.
|
# This will trigger the registered handler being called.
|
||||||
@ -298,7 +488,7 @@ class MessageReceiver:
|
|||||||
|
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
def register_handler(self, call: Callable[[Any, Message],
|
def register_handler(self, call: Callable[[Any, Message],
|
||||||
Message]) -> None:
|
Response]) -> None:
|
||||||
"""Register a handler call.
|
"""Register a handler call.
|
||||||
|
|
||||||
The message type handled by the call is determined by its
|
The message type handled by the call is determined by its
|
||||||
@ -366,14 +556,10 @@ class MessageReceiver:
|
|||||||
# Ok; we're good!
|
# Ok; we're good!
|
||||||
self._handlers[msgtype] = call
|
self._handlers[msgtype] = call
|
||||||
|
|
||||||
def validate_handler_completeness(self, warn_only: bool = False) -> None:
|
def validate(self, warn_only: bool = False) -> None:
|
||||||
"""Return whether this receiver handles all protocol messages.
|
"""Check for handler completeness, valid types, etc."""
|
||||||
|
|
||||||
Only messages having possible response types are considered, as
|
|
||||||
those are the only ones that can be sent to a receiver.
|
|
||||||
"""
|
|
||||||
for msgtype in self._protocol.message_ids_by_type.keys():
|
for msgtype in self._protocol.message_ids_by_type.keys():
|
||||||
if not msgtype.get_response_types():
|
if issubclass(msgtype, Response):
|
||||||
continue
|
continue
|
||||||
if msgtype not in self._handlers:
|
if msgtype not in self._handlers:
|
||||||
msg = (f'Protocol message {msgtype} not handled'
|
msg = (f'Protocol message {msgtype} not handled'
|
||||||
@ -388,6 +574,7 @@ class MessageReceiver:
|
|||||||
# Decode the incoming message.
|
# Decode the incoming message.
|
||||||
msg_decoded = self._protocol.message_decode(msg)
|
msg_decoded = self._protocol.message_decode(msg)
|
||||||
msgtype = type(msg_decoded)
|
msgtype = type(msg_decoded)
|
||||||
|
assert issubclass(msgtype, Message)
|
||||||
|
|
||||||
# Call the proper handler.
|
# Call the proper handler.
|
||||||
handler = self._handlers.get(msgtype)
|
handler = self._handlers.get(msgtype)
|
||||||
@ -396,7 +583,7 @@ class MessageReceiver:
|
|||||||
response = handler(bound_obj, msg_decoded)
|
response = handler(bound_obj, msg_decoded)
|
||||||
|
|
||||||
# Re-encode the response.
|
# Re-encode the response.
|
||||||
assert isinstance(response, Message)
|
assert isinstance(response, Response)
|
||||||
assert type(response) in msgtype.get_response_types()
|
assert type(response) in msgtype.get_response_types()
|
||||||
return self._protocol.message_encode(response)
|
return self._protocol.message_encode(response)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user