mirror of
https://github.com/RYDE-WORK/ballistica.git
synced 2026-01-27 17:33:13 +08:00
209 lines
5.5 KiB
Python
209 lines
5.5 KiB
Python
# Released under the MIT License. See LICENSE for details.
|
|
#
|
|
"""Testing message functionality."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, overload
|
|
from dataclasses import dataclass
|
|
|
|
import pytest
|
|
|
|
from efro.dataclassio import ioprepped
|
|
from efro.message import (Message, MessageProtocol, MessageSender,
|
|
MessageReceiver)
|
|
# from efrotools.statictest import static_type_equals
|
|
|
|
if TYPE_CHECKING:
|
|
from typing import List, Type, Any, Callable, Union
|
|
|
|
|
|
@ioprepped
@dataclass
class _TestMessage1(Message):
    """Test message carrying a single int value.

    Its only listed response type is _TestMessageR1.
    """

    ival: int

    @classmethod
    def get_response_types(cls) -> List[Type[Message]]:
        """Return the message types that may be sent back as a response."""
        # Annotated explicitly so the list is typed against the Message base.
        responses: List[Type[Message]] = [_TestMessageR1]
        return responses
|
|
|
|
|
|
@ioprepped
@dataclass
class _TestMessage2(Message):
    """Test message carrying a single str value.

    Lists two possible response types: _TestMessageR1 and _TestMessageR2.
    """

    sval: str

    @classmethod
    def get_response_types(cls) -> List[Type[Message]]:
        """Return the message types that may be sent back as a response."""
        # Annotated explicitly so the list is typed against the Message base.
        responses: List[Type[Message]] = [_TestMessageR1, _TestMessageR2]
        return responses
|
|
|
|
|
|
@ioprepped
@dataclass
class _TestMessageR1(Message):
    """Test response message carrying a single bool value.

    Listed as a response type by both _TestMessage1 and _TestMessage2.
    """

    bval: bool
|
|
|
|
|
|
@ioprepped
@dataclass
class _TestMessageR2(Message):
    """Test response message carrying a single float value.

    Listed as a response type by _TestMessage2.
    """

    fval: float
|
|
|
|
|
|
@ioprepped
@dataclass
class _TestMessageR3(Message):
    """Additional test message carrying a single float value.

    NOTE(review): nothing in the visible part of this module lists this
    type as a response or registers it in a protocol -- confirm whether
    that is intentional.
    """

    fval: float
|
|
|
|
|
|
class _TestMessageSender(MessageSender):
    """Hand-written sender subclass for testing typed message sending.

    Normally a class like this would be auto-generated based on the
    protocol.
    """

    def __get__(
        self, obj: Any, type_in: Any = None
    ) -> _BoundTestMessageSender:
        # Descriptor access: wrap this sender together with the object it
        # was reached through so send() can be called on the result.
        bound = _BoundTestMessageSender(obj, self)
        return bound
|
|
|
|
|
|
class _BoundTestMessageSender:
    """A _TestMessageSender paired with the object it was accessed on.

    The send() overloads tell type checkers which response type each
    message type yields. Normally this would be auto-generated based on
    the protocol.
    """

    def __init__(self, obj: Any, sender: _TestMessageSender) -> None:
        assert obj is not None
        self._sender = sender
        self._obj = obj

    @overload
    def send(self, message: _TestMessage1) -> _TestMessageR1:
        ...

    @overload
    def send(
        self, message: _TestMessage2
    ) -> Union[_TestMessageR1, _TestMessageR2]:
        ...

    def send(self, message: Any) -> Any:
        """Send a particular message type."""
        # Forward to the wrapped sender, supplying the bound object.
        sender, owner = self._sender, self._obj
        return sender.send(owner, message)
|
|
|
|
|
|
class _TestMessageReceiver(MessageReceiver):
    """Hand-written receiver subclass for testing typed message handlers.

    Normally a class like this would be auto-generated based on the
    protocol.
    """

    def __get__(
        self, obj: Any, type_in: Any = None
    ) -> _BoundTestMessageReceiver:
        # Descriptor access: wrap this receiver together with the object
        # it was reached through.
        bound = _BoundTestMessageReceiver(obj, self)
        return bound

    @overload
    def handler(
        self, call: Callable[[Any, _TestMessage1], _TestMessageR1]
    ) -> Callable[[Any, _TestMessage1], _TestMessageR1]:
        ...

    @overload
    def handler(
        self,
        call: Callable[[Any, _TestMessage2],
                       Union[_TestMessageR1, _TestMessageR2]],
    ) -> Callable[[Any, _TestMessage2], Union[_TestMessageR1, _TestMessageR2]]:
        ...

    def handler(self, call: Callable) -> Callable:
        """Decorator to register a handler for a particular message type."""
        self.register_handler(call)

        # Hand the callable back unchanged so the decorated name still
        # refers to the original function.
        return call
|
|
|
|
|
|
class _BoundTestMessageReceiver:
    """A _TestMessageReceiver paired with the object it was accessed on.

    Normally this would be auto-generated based on the protocol.
    """

    def __init__(self, obj: Any, receiver: _TestMessageReceiver) -> None:
        assert obj is not None
        self._receiver = receiver
        self._obj = obj
|
|
|
|
|
|
# Protocol shared by the sender/receiver tests in this module.
# Ids are spelled out explicitly; both request types and the response
# types they list are registered here.
TEST_PROTOCOL = MessageProtocol(
    message_types={
        1: _TestMessage1,
        2: _TestMessage2,
        3: _TestMessageR1,
        4: _TestMessageR2,
    }
)
|
|
|
|
|
|
def test_protocol_creation() -> None:
    """Test protocol creation with and without response types registered."""

    # _TestMessage1 lists _TestMessageR1 as a response type; leaving
    # _TestMessageR1 without an id is expected to raise ValueError.
    with pytest.raises(ValueError):
        _protocol = MessageProtocol(message_types={1: _TestMessage1})

    # With the response type given an id as well, creation should succeed.
    _protocol = MessageProtocol(
        message_types={
            1: _TestMessage1,
            2: _TestMessageR1,
        }
    )
|
|
|
|
|
|
def test_message_sending() -> None:
    """Test simple message sending.

    Defines a sender class and a receiver class locally, wires them to
    TEST_PROTOCOL, and sends one message of each request type through
    the sender. The results are not currently asserted on; the static
    type check is skipped (see the commented-out asserts at the end).
    """

    # Define a class that can send messages and one that can receive them.
    class TestClassS:
        """Test class incorporating send functionality."""

        msg = _TestMessageSender(TEST_PROTOCOL)

        def __init__(self, receiver: TestClassR) -> None:
            self._receiver = receiver

        @msg.send_raw_handler
        def _send_raw_message(self, data: bytes) -> bytes:
            """Stub raw-send hook: report the payload size, send nothing."""
            print(f'WOULD SEND RAW MSG OF SIZE: {len(data)}')
            return b''

    class TestClassR:
        """Test class incorporating receive functionality."""

        receiver = _TestMessageReceiver(TEST_PROTOCOL)

        @receiver.handler
        def handle_test_message_1(self, msg: _TestMessage1) -> _TestMessageR1:
            """Handle _TestMessage1 by returning a fixed _TestMessageR1."""
            del msg  # Unused
            print('Hello from test message 1 handler!')
            return _TestMessageR1(bval=True)

        @receiver.handler
        def handle_test_message_2(
            self, msg: _TestMessage2
        ) -> Union[_TestMessageR1, _TestMessageR2]:
            """Handle _TestMessage2 by returning a fixed _TestMessageR2."""
            del msg  # Unused
            # Previously printed 'test message 1' (copy-paste slip), which
            # made the two handlers indistinguishable in test output.
            print('Hello from test message 2 handler!')
            return _TestMessageR2(fval=1.2)

    obj_r = TestClassR()
    obj_s = TestClassS(receiver=obj_r)

    # Results are bound to underscore-prefixed names because nothing
    # inspects them yet.
    _result = obj_s.msg.send(_TestMessage1(ival=0))
    _result2 = obj_s.msg.send(_TestMessage2(sval='rah'))
    print('SKIPPING STATIC CHECK')
    # assert static_type_equals(result, _TestMessageR1)
    # assert isinstance(result, _TestMessageR1)
|