diff --git a/.idea/dictionaries/ericf.xml b/.idea/dictionaries/ericf.xml
index 17471111..896fa742 100644
--- a/.idea/dictionaries/ericf.xml
+++ b/.idea/dictionaries/ericf.xml
@@ -715,6 +715,7 @@
existables
expatbuilder
expatreader
+ expectedsig
explodable
explodey
exportoptions
@@ -1403,6 +1404,7 @@
mrmaxmeier
msbuild
msgdict
+ msgfull
msgtype
mshell
msvccompiler
diff --git a/ballisticacore-cmake/.idea/dictionaries/ericf.xml b/ballisticacore-cmake/.idea/dictionaries/ericf.xml
index 64cfb863..4192b9ee 100644
--- a/ballisticacore-cmake/.idea/dictionaries/ericf.xml
+++ b/ballisticacore-cmake/.idea/dictionaries/ericf.xml
@@ -332,6 +332,7 @@
exhash
exhashstr
expbool
+ expectedsig
expl
extrahash
extrascale
@@ -640,6 +641,7 @@
mqrspec
msaa
msgdict
+ msgfull
msgtype
mult
multing
diff --git a/docs/ba_module.md b/docs/ba_module.md
index 44eaafed..0a9694a2 100644
--- a/docs/ba_module.md
+++ b/docs/ba_module.md
@@ -1,5 +1,5 @@
-
last updated on 2021-09-05 for Ballistica version 1.6.5 build 20391
+last updated on 2021-09-07 for Ballistica version 1.6.5 build 20391
This page documents the Python classes and functions in the 'ba' module,
which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please let me know. Happy modding!
diff --git a/tests/test_efro/test_message.py b/tests/test_efro/test_message.py
index 418903ef..17554bf6 100644
--- a/tests/test_efro/test_message.py
+++ b/tests/test_efro/test_message.py
@@ -9,10 +9,11 @@ from dataclasses import dataclass
import pytest
+from efro.error import CleanError, RemoteError
from efro.dataclassio import ioprepped
from efro.message import (Message, MessageProtocol, MessageSender,
MessageReceiver)
-# from efrotools.statictest import static_type_equals
+from efrotools.statictest import static_type_equals
if TYPE_CHECKING:
from typing import List, Type, Any, Callable, Union
@@ -63,6 +64,7 @@ class _TestMessageR3(Message):
class _TestMessageSender(MessageSender):
"""Testing type overrides for message sending.
+
Normally this would be auto-generated based on the protocol.
"""
@@ -74,6 +76,7 @@ class _TestMessageSender(MessageSender):
class _BoundTestMessageSender:
"""Testing type overrides for message sending.
+
Normally this would be auto-generated based on the protocol.
"""
@@ -91,13 +94,14 @@ class _BoundTestMessageSender:
message: _TestMessage2) -> Union[_TestMessageR1, _TestMessageR2]:
...
- def send(self, message: Any) -> Any:
+ def send(self, message: Message) -> Message:
"""Send a particular message type."""
return self._sender.send(self._obj, message)
class _TestMessageReceiver(MessageReceiver):
"""Testing type overrides for message receiving.
+
Normally this would be auto-generated based on the protocol.
"""
@@ -127,6 +131,7 @@ class _TestMessageReceiver(MessageReceiver):
class _BoundTestMessageReceiver:
"""Testing type overrides for message receiving.
+
Normally this would be auto-generated based on the protocol.
"""
@@ -135,13 +140,21 @@ class _BoundTestMessageReceiver:
self._obj = obj
self._receiver = receiver
+ def handle_raw_message(self, message: bytes) -> bytes:
+ """Handle a raw incoming message."""
+ return self._receiver.handle_raw_message(self._obj, message)
-TEST_PROTOCOL = MessageProtocol(message_types={
- 1: _TestMessage1,
- 2: _TestMessage2,
- 3: _TestMessageR1,
- 4: _TestMessageR2,
-})
+
+TEST_PROTOCOL = MessageProtocol(
+ message_types={
+ 1: _TestMessage1,
+ 2: _TestMessage2,
+ 3: _TestMessageR1,
+ 4: _TestMessageR2,
+ },
+ trusted_client=True,
+ log_remote_exceptions=False,
+)
def test_protocol_creation() -> None:
@@ -159,6 +172,37 @@ def test_protocol_creation() -> None:
})
+def test_receiver_creation() -> None:
+ """Test receiver creation"""
+
+ # This should fail due to the registered handler only specifying
+ # one response message type while the message type itself
+ # specifies two.
+ with pytest.raises(TypeError):
+
+ class _TestClassR:
+ """Test class incorporating receive functionality."""
+
+ receiver = _TestMessageReceiver(TEST_PROTOCOL)
+
+ @receiver.handler
+ def handle_test_message_2(self,
+ msg: _TestMessage2) -> _TestMessageR2:
+ """Test."""
+ del msg # Unused
+ print('Hello from test message 1 handler!')
+ return _TestMessageR2(fval=1.2)
+
+ # Should fail because not all message types in the protocol are handled.
+ with pytest.raises(TypeError):
+
+ class _TestClassR2:
+ """Test class incorporating receive functionality."""
+
+ receiver = _TestMessageReceiver(TEST_PROTOCOL)
+ receiver.validate_handler_completeness()
+
+
def test_message_sending() -> None:
"""Test simple message sending."""
@@ -168,14 +212,13 @@ def test_message_sending() -> None:
msg = _TestMessageSender(TEST_PROTOCOL)
- def __init__(self, receiver: TestClassR) -> None:
- self._receiver = receiver
+ def __init__(self, target: TestClassR) -> None:
+ self._target = target
@msg.send_raw_handler
def _send_raw_message(self, data: bytes) -> bytes:
"""Test."""
- print(f'WOULD SEND RAW MSG OF SIZE: {len(data)}')
- return b''
+ return self._target.receiver.handle_raw_message(data)
class TestClassR:
"""Test class incorporating receive functionality."""
@@ -185,8 +228,11 @@ def test_message_sending() -> None:
@receiver.handler
def handle_test_message_1(self, msg: _TestMessage1) -> _TestMessageR1:
"""Test."""
- del msg # Unused
print('Hello from test message 1 handler!')
+ if msg.ival == 1:
+ raise CleanError('Testing Clean Error')
+ if msg.ival == 2:
+ raise RuntimeError('Testing Runtime Error')
return _TestMessageR1(bval=True)
@receiver.handler
@@ -195,14 +241,27 @@ def test_message_sending() -> None:
msg: _TestMessage2) -> Union[_TestMessageR1, _TestMessageR2]:
"""Test."""
del msg # Unused
- print('Hello from test message 1 handler!')
+ print('Hello from test message 2 handler!')
return _TestMessageR2(fval=1.2)
- obj_r = TestClassR()
- obj_s = TestClassS(receiver=obj_r)
+ receiver.validate_handler_completeness()
- _result = obj_s.msg.send(_TestMessage1(ival=0))
- _result2 = obj_s.msg.send(_TestMessage2(sval='rah'))
- print('SKIPPING STATIC CHECK')
- # assert static_type_equals(result, _TestMessageR1)
- # assert isinstance(result, _TestMessageR1)
+ obj_r = TestClassR()
+ obj_s = TestClassS(target=obj_r)
+
+ response = obj_s.msg.send(_TestMessage1(ival=0))
+ response2 = obj_s.msg.send(_TestMessage2(sval='rah'))
+ assert static_type_equals(response, _TestMessageR1)
+ assert isinstance(response, _TestMessageR1)
+ assert isinstance(response2, (_TestMessageR1, _TestMessageR2))
+
+ # Remote CleanErrors should come across locally as the same.
+ try:
+ _response3 = obj_s.msg.send(_TestMessage1(ival=1))
+ except Exception as exc:
+ assert isinstance(exc, CleanError)
+ assert str(exc) == 'Testing Clean Error'
+
+ # Other remote errors should come across as RemoteError.
+ with pytest.raises(RemoteError):
+ _response4 = obj_s.msg.send(_TestMessage1(ival=2))
diff --git a/tools/efro/message.py b/tools/efro/message.py
index 29f144a9..5ef4a130 100644
--- a/tools/efro/message.py
+++ b/tools/efro/message.py
@@ -9,12 +9,16 @@ from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar
from dataclasses import dataclass
from enum import Enum
+import inspect
+import logging
import json
+import traceback
from typing_extensions import Annotated
+from efro.error import CleanError, RemoteError
from efro.dataclassio import (ioprepped, is_ioprepped_dataclass, IOAttrs,
- dataclass_to_dict)
+ dataclass_to_dict, dataclass_from_dict)
if TYPE_CHECKING:
from typing import (Dict, Type, Tuple, List, Any, Callable, Optional, Set,
@@ -63,7 +67,8 @@ class MessageProtocol:
message_types: Dict[int, Type[Message]],
type_key: Optional[str] = None,
preserve_clean_errors: bool = True,
- remote_stack_traces: bool = False) -> None:
+ log_remote_exceptions: bool = True,
+ trusted_client: bool = False) -> None:
"""Create a protocol with a given configuration.
Each entry for message_types should contain an ID, a message type,
and all possible response types.
@@ -77,23 +82,24 @@ class MessageProtocol:
on the remote end will result in the same error raised locally.
All other Exception types come across as efro.error.RemoteError.
- If 'remote_stack_traces' is True, stringified remote stack traces will
+ If 'trusted_client' is True, stringified remote stack traces will
be included in the RemoteError. This should only be enabled in cases
where the client is trusted.
"""
- self._message_types_by_id: Dict[int, Type[Message]] = {}
- self._message_ids_by_type: Dict[Type[Message], int] = {}
+ self.message_types_by_id: Dict[int, Type[Message]] = {}
+ self.message_ids_by_type: Dict[Type[Message], int] = {}
for m_id, m_type in message_types.items():
# Make sure only valid message types were passed and each
# id was assigned only once.
assert isinstance(m_id, int)
+ assert m_id >= 0
assert (is_ioprepped_dataclass(m_type)
and issubclass(m_type, Message))
- assert self._message_types_by_id.get(m_id) is None
+ assert self.message_types_by_id.get(m_id) is None
- self._message_types_by_id[m_id] = m_type
- self._message_ids_by_type[m_type] = m_id
+ self.message_types_by_id[m_id] = m_type
+ self.message_ids_by_type[m_type] = m_id
# Make sure all return types are valid and have been assigned
# an ID as well.
@@ -106,21 +112,28 @@ class MessageProtocol:
all_response_types.update(m_rtypes)
for cls in all_response_types:
assert is_ioprepped_dataclass(cls) and issubclass(cls, Message)
- if cls not in self._message_ids_by_type:
+ if cls not in self.message_ids_by_type:
raise ValueError(f'Possible response type {cls}'
f' was not included in message_types.')
self._type_key = type_key
- self._preserve_clean_errors = preserve_clean_errors
- self._remote_stack_traces = remote_stack_traces
+ self.preserve_clean_errors = preserve_clean_errors
+ self.log_remote_exceptions = log_remote_exceptions
+ self.trusted_client = trusted_client
- def message_encode(self, message: Message) -> bytes:
- """Encode a message to bytes for sending."""
+ def message_encode(self,
+ message: Message,
+ is_error: bool = False) -> bytes:
+ """Encode a message to bytes for transport."""
- m_id = self._message_ids_by_type.get(type(message))
- if m_id is None:
- raise TypeError(f'Message type is not registered in Protocol:'
- f' {type(message)}')
+ m_id: Optional[int]
+ if is_error:
+ m_id = -1
+ else:
+ m_id = self.message_ids_by_type.get(type(message))
+ if m_id is None:
+ raise TypeError(f'Message type is not registered in Protocol:'
+ f' {type(message)}')
msgdict = dataclass_to_dict(message)
# Encode type as part of the message dict if desired
@@ -136,9 +149,38 @@ class MessageProtocol:
return json.dumps(out, separators=(',', ':')).encode()
def message_decode(self, data: bytes) -> Message:
- """Decode a message from bytes."""
- print(f'WOULD DECODE MSG FROM RAW: {str(data)}')
- return Message()
+ """Decode a message from bytes.
+
+ If the message represents a remote error, an Exception will
+ be raised.
+ """
+ msgfull = json.loads(data.decode())
+ assert isinstance(msgfull, dict)
+ msgdict: Optional[dict]
+ if self._type_key is not None:
+ m_id = msgfull.pop(self._type_key)
+ msgdict = msgfull
+ assert isinstance(m_id, int)
+ else:
+ m_id = msgfull.get('t')
+ msgdict = msgfull.get('m')
+ assert isinstance(m_id, int)
+ assert isinstance(msgdict, dict)
+
+ # Special case: a remote error occurred. Raise a local Exception.
+ if m_id == -1:
+ err = dataclass_from_dict(RemoteErrorMessage, msgdict)
+ if (self.preserve_clean_errors
+ and err.error_type is RemoteErrorType.CLEAN):
+ raise CleanError(err.error_message)
+ raise RemoteError(err.error_message)
+
+ # Decode this particular type and make sure its valid.
+ msgtype = self.message_types_by_id.get(m_id)
+ if msgtype is None:
+ raise TypeError(f'Got unregistered message type id of {m_id}.')
+
+ return dataclass_from_dict(msgtype, msgdict)
def create_sender_module(self, classname: str) -> str:
""""Create a Python module defining a MessageSender subclass.
@@ -156,23 +198,6 @@ class MessageProtocol:
the protocol.
"""
- def validate_message_type(self, msgtype: Type,
- responsetypes: Sequence[Type]) -> None:
- """Ensure message type associated response types are valid.
- Raises an exception if not.
- """
- if msgtype not in self._message_ids_by_type:
- raise TypeError(f'Message type {msgtype} is not registered'
- f' in this Protocol.')
-
- # Make sure the responses exactly matches what the message expects.
- assert len(set(responsetypes)) == len(responsetypes)
-
- for responsetype in responsetypes:
- if responsetype not in self._message_ids_by_type:
- raise TypeError(f'Response message type {responsetype} is'
- f' not registered in this Protocol.')
-
class MessageSender:
"""Facilitates sending messages to a target and receiving responses.
@@ -207,17 +232,27 @@ class MessageSender:
self._send_raw_message_call = call
return call
- def send(self, bound_obj: Any, message: Message) -> Any:
+ def send(self, bound_obj: Any, message: Message) -> Message:
"""Send a message and receive a response.
+
Will encode the message for transport and call dispatch_raw_message()
"""
if self._send_raw_message_call is None:
raise RuntimeError('send() is unimplemented for this type.')
- encoded = self._protocol.message_encode(message)
- return self._send_raw_message_call(bound_obj, encoded)
+
+ # Only types with possible response types should ever be sent.
+ assert type(message).get_response_types()
+
+ msg_encoded = self._protocol.message_encode(message)
+ response_encoded = self._send_raw_message_call(bound_obj, msg_encoded)
+ response = self._protocol.message_decode(response_encoded)
+ assert isinstance(response, Message)
+ assert type(response) in type(message).get_response_types()
+ return response
def send_bg(self, bound_obj: Any, message: Message) -> Message:
"""Send a message asynchronously and receive a future.
+
The message will be encoded for transport and passed to
dispatch_raw_message from a background thread.
"""
@@ -225,6 +260,7 @@ class MessageSender:
def send_async(self, bound_obj: Any, message: Message) -> Message:
"""Send a message asynchronously using asyncio.
+
The message will be encoded for transport and passed to
dispatch_raw_message_async.
"""
@@ -233,6 +269,7 @@ class MessageSender:
class MessageReceiver:
"""Facilitates receiving & responding to messages from a remote source.
+
This is instantiated at the class level with unbound methods registered
as handlers for different message types in the protocol.
@@ -257,10 +294,13 @@ class MessageReceiver:
def __init__(self, protocol: MessageProtocol) -> None:
self._protocol = protocol
+ self._handlers: Dict[Type[Message], Callable] = {}
# noinspection PyProtectedMember
- def register_handler(self, call: Callable) -> None:
+ def register_handler(self, call: Callable[[Any, Message],
+ Message]) -> None:
"""Register a handler call.
+
The message type handled by the call is determined by its
type annotation.
"""
@@ -268,15 +308,26 @@ class MessageReceiver:
from typing import _GenericAlias # type: ignore
from typing import Union, get_type_hints, get_args
+ sig = inspect.getfullargspec(call)
+
+ # The provided callable should be a method taking one 'msg' arg.
+ expectedsig = ['self', 'msg']
+ if sig.args != expectedsig:
+ raise ValueError(f'Expected callable signature of {expectedsig};'
+ f' got {sig.args}')
+
+ # Check annotation types to determine what message types we handle.
# Return-type annotation can be a Union, but we probably don't
# have it available at runtime. Explicitly pull it in.
anns = get_type_hints(call, localns={'Union': Union})
- msg = anns.get('msg')
- if not isinstance(msg, type):
+ msgtype = anns.get('msg')
+ if not isinstance(msgtype, type):
raise TypeError(
- f'expected a type for "msg" annotation; got {type(msg)}.')
+ f'expected a type for "msg" annotation; got {type(msgtype)}.')
+ assert issubclass(msgtype, Message)
+
ret = anns.get('return')
- rets: Tuple[Type, ...]
+ responsetypes: Tuple[Type, ...]
# Return types can be a single type or a union of types.
if isinstance(ret, _GenericAlias):
@@ -284,27 +335,93 @@ class MessageReceiver:
if not all(isinstance(a, type) for a in targs):
raise TypeError(f'expected only types for "return" annotation;'
f' got {targs}.')
- rets = targs
-
- print(f'LOOKED AT GENERIC ALIAS {targs}')
+ responsetypes = targs
else:
if not isinstance(ret, type):
raise TypeError(f'expected one or more types for'
f' "return" annotation; got a {type(ret)}.')
- rets = (ret, )
+ responsetypes = (ret, )
- print(f'WOULD REGISTER HANDLER! (got {msg} and {rets})')
+ # Make sure our protocol has this message type registered and our
+ # return types exactly match. (Technically we could return a subset
+ # of the supported types; can allow this in the future if it makes
+ # sense).
+ registered_types = self._protocol.message_ids_by_type.keys()
- def handle_raw_message(self, msg: bytes) -> bytes:
- """Should be called when the receiver gets a message.
- The return value is the raw response to the message.
+ if msgtype not in registered_types:
+ raise TypeError(f'Message type {msgtype} is not registered'
+ f' in this Protocol.')
+
+ if msgtype in self._handlers:
+ raise TypeError(f'Message type {msgtype} already has a registered'
+ f' handler.')
+
+ # Make sure the responses exactly matches what the message expects.
+ if set(responsetypes) != set(msgtype.get_response_types()):
+ raise TypeError(
+ f'Provided response types {responsetypes} do not'
+ f' match the set expected by message type {msgtype}: '
+ f'({msgtype.get_response_types()})')
+
+ # Ok; we're good!
+ self._handlers[msgtype] = call
+
+ def validate_handler_completeness(self, warn_only: bool = False) -> None:
+ """Return whether this receiver handles all protocol messages.
+
+ Only messages having possible response types are considered, as
+ those are the only ones that can be sent to a receiver.
"""
- print('RECEIVER WOULD HANDLE RAW MESSAGE')
- del msg # Unused
- return b''
+ for msgtype in self._protocol.message_ids_by_type.keys():
+ if not msgtype.get_response_types():
+ continue
+ if msgtype not in self._handlers:
+ msg = (f'Protocol message {msgtype} not handled'
+ f' by receiver.')
+ if warn_only:
+ logging.warning(msg)
+ raise TypeError(msg)
+
+ def handle_raw_message(self, bound_obj: Any, msg: bytes) -> bytes:
+ """Decode, handle, and return encoded response for a message."""
+ try:
+ # Decode the incoming message.
+ msg_decoded = self._protocol.message_decode(msg)
+ msgtype = type(msg_decoded)
+
+ # Call the proper handler.
+ handler = self._handlers.get(msgtype)
+ if handler is None:
+ raise RuntimeError(f'Got unhandled message type: {msgtype}.')
+ response = handler(bound_obj, msg_decoded)
+
+ # Re-encode the response.
+ assert isinstance(response, Message)
+ assert type(response) in msgtype.get_response_types()
+ return self._protocol.message_encode(response)
+
+ except Exception as exc:
+
+ if self._protocol.log_remote_exceptions:
+ logging.exception('Error handling message.')
+
+ # If anything goes wrong, return a RemoteErrorMessage instead.
+ if (isinstance(exc, CleanError)
+ and self._protocol.preserve_clean_errors):
+ response = RemoteErrorMessage(error_message=str(exc),
+ error_type=RemoteErrorType.CLEAN)
+ else:
+
+ response = RemoteErrorMessage(
+ error_message=(traceback.format_exc()
+ if self._protocol.trusted_client else
+ 'An unknown error has occurred.'),
+ error_type=RemoteErrorType.OTHER)
+ return self._protocol.message_encode(response, is_error=True)
async def handle_raw_message_async(self, msg: bytes) -> bytes:
"""Should be called when the receiver gets a message.
+
The return value is the raw response to the message.
"""
raise RuntimeError('Unimplemented!')