diff --git a/tests/test_efro/test_message.py b/tests/test_efro/test_message.py index a2100cd9..2a24a008 100644 --- a/tests/test_efro/test_message.py +++ b/tests/test_efro/test_message.py @@ -179,7 +179,7 @@ class _BoundTestSyncMessageReceiver: self._obj = obj self._receiver = receiver - def handle_raw_message(self, message: bytes) -> bytes: + def handle_raw_message(self, message: str) -> str: """Synchronously handle a raw incoming message.""" return self._receiver.handle_raw_message(self._obj, message) @@ -239,7 +239,7 @@ class _BoundTestAsyncMessageReceiver: self._obj = obj self._receiver = receiver - async def handle_raw_message(self, message: bytes) -> bytes: + async def handle_raw_message(self, message: str) -> str: """Asynchronously handle a raw incoming message.""" return await self._receiver.handle_raw_message_async( self._obj, message) @@ -396,16 +396,16 @@ def test_full_pipeline() -> None: self._target = target @msg.send_method - def _send_raw_message(self, data: bytes) -> bytes: - """Handle synchronous sending of raw message data.""" + def _send_raw_message(self, data: str) -> str: + """Handle synchronous sending of raw json message data.""" # Just talk directly to the receiver for this example. # (currently only support synchronous receivers) assert isinstance(self._target, TestClassRSync) return self._target.receiver.handle_raw_message(data) @msg.send_async_method - async def _send_raw_message_async(self, data: bytes) -> bytes: - """Handle asynchronous sending of raw message data.""" + async def _send_raw_message_async(self, data: str) -> str: + """Handle asynchronous sending of raw json message data.""" # Just talk directly to the receiver for this example. # (we can do sync or async receivers) if isinstance(self._target, TestClassRSync): diff --git a/tools/efro/message.py b/tools/efro/message.py index 61e461ed..848f74ec 100644 --- a/tools/efro/message.py +++ b/tools/efro/message.py @@ -179,17 +179,17 @@ class MessageProtocol: self.log_remote_exceptions = log_remote_exceptions self.trusted_sender = trusted_sender - def encode_message(self, message: Message) -> bytes: - """Encode a message to bytes for transport.""" + def encode_message(self, message: Message) -> str: + """Encode a message to a json string for transport.""" return self._encode(message, self.message_ids_by_type, 'message') - def encode_response(self, response: Response) -> bytes: - """Encode a response to bytes for transport.""" + def encode_response(self, response: Response) -> str: + """Encode a response to a json string for transport.""" return self._encode(response, self.response_ids_by_type, 'response') def _encode(self, message: Any, ids_by_type: Dict[Type, int], - opname: str) -> bytes: - """Encode a message to bytes for transport.""" + opname: str) -> str: + """Encode a message to a json string for transport.""" m_id: Optional[int] = ids_by_type.get(type(message)) if m_id is None: @@ -207,24 +207,24 @@ class MessageProtocol: out = msgdict else: out = {'m': msgdict, 't': m_id} - return json.dumps(out, separators=(',', ':')).encode() + return json.dumps(out, separators=(',', ':')) - def decode_message(self, data: bytes) -> Message: - """Decode a message from bytes.""" + def decode_message(self, data: str) -> Message: + """Decode a message from a json string.""" out = self._decode(data, self.message_types_by_id, 'message') assert isinstance(out, Message) return out - def decode_response(self, data: bytes) -> Optional[Response]: - """Decode a response from bytes.""" + def decode_response(self, data: str) -> Optional[Response]: + """Decode a response from a json string.""" out = self._decode(data, self.response_types_by_id, 'response') assert isinstance(out, (Response, type(None))) return out - def _decode(self, data: bytes, types_by_id: Dict[int, Type], + def _decode(self, data: str, types_by_id: Dict[int, Type], opname: str) -> Any: - """Decode a message from bytes.""" - msgfull = json.loads(data.decode()) + """Decode a message from a json string.""" + msgfull = json.loads(data) assert isinstance(msgfull, dict) msgdict: Optional[dict] if self._type_key is not None: @@ -487,8 +487,8 @@ class MessageProtocol: if is_async: out += ( '\n' - ' async def handle_raw_message(self, message: bytes)' - ' -> bytes:\n' + ' async def handle_raw_message(self, message: str)' + ' -> str:\n' ' """Asynchronously handle a raw incoming message."""\n' ' return await' ' self._receiver.handle_raw_message_async(\n' @@ -496,7 +496,7 @@ class MessageProtocol: else: out += ( '\n' - ' def handle_raw_message(self, message: bytes) -> bytes:\n' + ' def handle_raw_message(self, message: str) -> str:\n' ' """Synchronously handle a raw incoming message."""\n' ' return self._receiver.handle_raw_message' '(self._obj, message)\n') @@ -515,7 +515,7 @@ class MessageSender: msg = MyMessageSender(some_protocol) @msg.sendmethod - def send_raw_message(self, message: bytes) -> bytes: + def send_raw_message(self, message: str) -> str: # Actually send the message here. # MyMessageSender class should provide overloads for send(), send_bg(), @@ -526,22 +526,21 @@ class MessageSender: def __init__(self, protocol: MessageProtocol) -> None: self._protocol = protocol - self._send_raw_message_call: Optional[Callable[[Any, bytes], - bytes]] = None + self._send_raw_message_call: Optional[Callable[[Any, str], str]] = None self._send_async_raw_message_call: Optional[Callable[ - [Any, bytes], Awaitable[bytes]]] = None + [Any, str], Awaitable[str]]] = None def send_method( - self, call: Callable[[Any, bytes], - bytes]) -> Callable[[Any, bytes], bytes]: + self, call: Callable[[Any, str], + str]) -> Callable[[Any, str], str]: """Function decorator for setting raw send method.""" assert self._send_raw_message_call is None self._send_raw_message_call = call return call def send_async_method( - self, call: Callable[[Any, bytes], Awaitable[bytes]] - ) -> Callable[[Any, bytes], Awaitable[bytes]]: + self, call: Callable[[Any, str], Awaitable[str]] + ) -> Callable[[Any, str], Awaitable[str]]: """Function decorator for setting raw send-async method.""" assert self._send_async_raw_message_call is None self._send_async_raw_message_call = call @@ -711,7 +710,7 @@ class MessageReceiver: raise TypeError(msg) def _decode_incoming_message(self, - msg: bytes) -> Tuple[Message, Type[Message]]: + msg: str) -> Tuple[Message, Type[Message]]: # Decode the incoming message. msg_decoded = self._protocol.decode_message(msg) msgtype = type(msg_decoded) @@ -719,7 +718,7 @@ class MessageReceiver: return msg_decoded, msgtype def _encode_response(self, response: Optional[Response], - msgtype: Type[Message]) -> bytes: + msgtype: Type[Message]) -> str: # A return value of None equals EmptyResponse. if response is None: @@ -732,7 +731,7 @@ class MessageReceiver: assert type(response) in msgtype.get_response_types() return self._protocol.encode_response(response) - def _handle_raw_message_error(self, exc: Exception) -> bytes: + def _handle_raw_message_error(self, exc: Exception) -> str: if self._protocol.log_remote_exceptions: logging.exception('Error handling message.') @@ -749,8 +748,8 @@ class MessageReceiver: error_type=ErrorType.OTHER) return self._protocol.encode_response(err_response) - def handle_raw_message(self, bound_obj: Any, msg: bytes) -> bytes: - """Decode, handle, and return an encoded response for a message.""" + def handle_raw_message(self, bound_obj: Any, msg: str) -> str: + """Decode, handle, and return an response for a message.""" assert not self.is_async, "can't call sync handler on async receiver" try: msg_decoded, msgtype = self._decode_incoming_message(msg) @@ -763,8 +762,7 @@ class MessageReceiver: except Exception as exc: return self._handle_raw_message_error(exc) - async def handle_raw_message_async(self, bound_obj: Any, - msg: bytes) -> bytes: + async def handle_raw_message_async(self, bound_obj: Any, msg: str) -> str: """Should be called when the receiver gets a message. The return value is the raw response to the message.