ballistica/tests/test_efro/test_message.py

209 lines
5.5 KiB
Python

# Released under the MIT License. See LICENSE for details.
#
"""Testing message functionality."""
from __future__ import annotations
from typing import TYPE_CHECKING, overload
from dataclasses import dataclass
import pytest
from efro.dataclassio import ioprepped
from efro.message import (Message, MessageProtocol, MessageSender,
MessageReceiver)
# from efrotools.statictest import static_type_equals
if TYPE_CHECKING:
from typing import List, Type, Any, Callable, Union
@ioprepped
@dataclass
class _TestMessage1(Message):
    """Test request message carrying a single integer.

    Its only declared response type is _TestMessageR1.
    """

    ival: int  # Integer payload.

    @classmethod
    def get_response_types(cls) -> List[Type[Message]]:
        """Return the message types this message may be answered with."""
        response_types: List[Type[Message]] = [_TestMessageR1]
        return response_types
@ioprepped
@dataclass
class _TestMessage2(Message):
    """Test request message carrying a single string.

    It may be answered with either _TestMessageR1 or _TestMessageR2.
    """

    sval: str  # String payload.

    @classmethod
    def get_response_types(cls) -> List[Type[Message]]:
        """Return the message types this message may be answered with."""
        response_types: List[Type[Message]] = [_TestMessageR1, _TestMessageR2]
        return response_types
@ioprepped
@dataclass
class _TestMessageR1(Message):
    """Test response message carrying a single boolean.

    Listed as a response type of both _TestMessage1 and _TestMessage2.
    """

    bval: bool  # Boolean payload.
@ioprepped
@dataclass
class _TestMessageR2(Message):
    """Test response message carrying a single float.

    Listed as a response type of _TestMessage2.
    """

    fval: float  # Float payload.
@ioprepped
@dataclass
class _TestMessageR3(Message):
    """Test message carrying a single float (same shape as _TestMessageR2).

    NOTE(review): not registered in TEST_PROTOCOL and not referenced by any
    test in this module as written; presumably kept for future cases.
    """

    fval: float  # Float payload.
class _TestMessageSender(MessageSender):
    """Hand-written typed message sender for the test protocol.

    Normally this would be auto-generated based on the protocol. Attribute
    access yields a _BoundTestMessageSender tied to the accessing object.
    NOTE: access through the class itself passes obj=None, which the bound
    sender's constructor rejects via an assert.
    """

    def __get__(self,
                obj: Any,
                type_in: Any = None) -> _BoundTestMessageSender:
        # Pair this sender with the owning object so the bound wrapper can
        # supply it on every send() call.
        bound = _BoundTestMessageSender(obj, self)
        return bound
class _BoundTestMessageSender:
"""Testing type overrides for message sending.
Normally this would be auto-generated based on the protocol.
"""
def __init__(self, obj: Any, sender: _TestMessageSender) -> None:
assert obj is not None
self._obj = obj
self._sender = sender
@overload
def send(self, message: _TestMessage1) -> _TestMessageR1:
...
@overload
def send(self,
message: _TestMessage2) -> Union[_TestMessageR1, _TestMessageR2]:
...
def send(self, message: Any) -> Any:
"""Send a particular message type."""
return self._sender.send(self._obj, message)
class _TestMessageReceiver(MessageReceiver):
    """Hand-written typed message receiver for the test protocol.

    Normally this would be auto-generated based on the protocol. The
    handler() overloads pin the expected signature of each message handler.
    """

    def __get__(self,
                obj: Any,
                type_in: Any = None) -> _BoundTestMessageReceiver:
        # Pair this receiver with the object it was accessed through.
        bound = _BoundTestMessageReceiver(obj, self)
        return bound

    @overload
    def handler(
        self, call: Callable[[Any, _TestMessage1], _TestMessageR1]
    ) -> Callable[[Any, _TestMessage1], _TestMessageR1]:
        ...

    @overload
    def handler(
        self, call: Callable[[Any, _TestMessage2], Union[_TestMessageR1,
                                                         _TestMessageR2]]
    ) -> Callable[[Any, _TestMessage2], Union[_TestMessageR1, _TestMessageR2]]:
        ...

    def handler(self, call: Callable) -> Callable:
        """Decorator: register *call* as a handler and return it unchanged."""
        self.register_handler(call)
        return call
class _BoundTestMessageReceiver:
"""Testing type overrides for message receiving.
Normally this would be auto-generated based on the protocol.
"""
def __init__(self, obj: Any, receiver: _TestMessageReceiver) -> None:
assert obj is not None
self._obj = obj
self._receiver = receiver
# Shared protocol used by the tests below. Both request types and every
# response type they declare must be given an id here.
TEST_PROTOCOL = MessageProtocol(
    message_types={
        # Requests.
        1: _TestMessage1,
        2: _TestMessage2,
        # Responses.
        3: _TestMessageR1,
        4: _TestMessageR2,
    },
)
def test_protocol_creation() -> None:
    """Check that protocols must assign ids to all declared response types."""

    # _TestMessage1 declares _TestMessageR1 as a response type, but R1 gets
    # no id in this mapping, so construction is expected to fail.
    missing_response_ids = {1: _TestMessage1}
    with pytest.raises(ValueError):
        _protocol = MessageProtocol(message_types=missing_response_ids)

    # With the response type registered too, construction succeeds.
    complete_ids = {1: _TestMessage1, 2: _TestMessageR1}
    _protocol = MessageProtocol(message_types=complete_ids)
def test_message_sending() -> None:
    """Test simple message sending.

    Builds a sender class and a receiver class around TEST_PROTOCOL, then
    sends one instance of each request message type through the sender.
    """

    # Define a class that can send messages and one that can receive them.
    class TestClassS:
        """Test class incorporating send functionality."""

        msg = _TestMessageSender(TEST_PROTOCOL)

        def __init__(self, receiver: TestClassR) -> None:
            self._receiver = receiver

        @msg.send_raw_handler
        def _send_raw_message(self, data: bytes) -> bytes:
            """Test."""
            # NOTE(review): this does not forward to self._receiver; it only
            # reports the payload size and returns an empty response.
            print(f'WOULD SEND RAW MSG OF SIZE: {len(data)}')
            return b''

    class TestClassR:
        """Test class incorporating receive functionality."""

        receiver = _TestMessageReceiver(TEST_PROTOCOL)

        @receiver.handler
        def handle_test_message_1(self, msg: _TestMessage1) -> _TestMessageR1:
            """Test."""
            del msg  # Unused
            print('Hello from test message 1 handler!')
            return _TestMessageR1(bval=True)

        @receiver.handler
        def handle_test_message_2(
                self,
                msg: _TestMessage2) -> Union[_TestMessageR1, _TestMessageR2]:
            """Test."""
            del msg  # Unused
            # Previously a copy-paste of the message 1 text; identify the
            # correct handler.
            print('Hello from test message 2 handler!')
            return _TestMessageR2(fval=1.2)

    obj_r = TestClassR()
    obj_s = TestClassS(receiver=obj_r)

    _result = obj_s.msg.send(_TestMessage1(ival=0))
    _result2 = obj_s.msg.send(_TestMessage2(sval='rah'))
    print('SKIPPING STATIC CHECK')
    # assert static_type_equals(_result, _TestMessageR1)
    # assert isinstance(_result, _TestMessageR1)