ballistica/tests/test_efro/test_rpc.py
2022-03-18 22:16:38 -05:00

377 lines
12 KiB
Python

# Released under the MIT License. See LICENSE for details.
#
"""Testing rpc functionality."""
from __future__ import annotations
import time
import random
import asyncio
import weakref
from enum import unique, Enum
from typing import TYPE_CHECKING
from dataclasses import dataclass
import pytest
from efro.rpc import RPCEndpoint
from efro.error import CommunicationError
from efro.dataclassio import ioprepped, dataclass_from_json, dataclass_to_json
if TYPE_CHECKING:
from typing import Optional, Awaitable, Callable
# Loopback address used by both the test server and the test client.
ADDR = '127.0.0.1'
# Randomize the port a bit so that test runs happening in parallel are
# unlikely to collide with each other. Ideally we would let the OS pick a
# free port and pass the resulting one to the client instead.
PORT = random.randrange(8000, 16000)
# Seconds that the TEST_SLOW message handler sleeps before it responds.
SLOW_WAIT = 1.0
@unique
class _MessageType(Enum):
TEST1 = 't1'
RESPONSE1 = 'r1'
TEST2 = '2'
RESPONSE2 = 'r2'
TEST_SLOW = 'ts'
RESPONSE_SLOW = 'rs'
@ioprepped
@dataclass
class _Message:
    """High-level message passed between the test endpoints."""

    # Which kind of request or response this message is.
    messagetype: _MessageType
class _ServerClientCommon:
def __init__(self, keepalive_interval: float, keepalive_timeout: float,
debug_print: bool) -> None:
self._endpoint: Optional[RPCEndpoint] = None
self._keepalive_interval = keepalive_interval
self._keepalive_timeout = keepalive_timeout
self._debug_print = debug_print
def has_endpoint(self) -> bool:
"""Is our endpoint up yet?"""
return self._endpoint is not None
@property
def endpoint(self) -> RPCEndpoint:
"""Our endpoint."""
if self._endpoint is None:
raise RuntimeError('Expected endpoint to exist.')
return self._endpoint
async def send_message(self,
message: _Message,
timeout: Optional[float] = None) -> _Message:
"""Send high level messages."""
assert self._endpoint is not None
response = await self._endpoint.send_message(
dataclass_to_json(message).encode(), timeout=timeout)
return dataclass_from_json(_Message, response.decode())
async def handle_message(self, msg: _Message) -> _Message:
"""Handle a high-level message."""
if msg.messagetype is _MessageType.TEST1:
return _Message(_MessageType.RESPONSE1)
if msg.messagetype is _MessageType.TEST2:
return _Message(_MessageType.RESPONSE2)
if msg.messagetype is _MessageType.TEST_SLOW:
await asyncio.sleep(SLOW_WAIT)
return _Message(_MessageType.RESPONSE_SLOW)
raise RuntimeError(f'Got unexpected message type: {msg.messagetype}')
async def _handle_raw_message(self, message: bytes) -> bytes:
msgobj = dataclass_from_json(_Message, message.decode())
rspobj = await self.handle_message(msgobj)
return dataclass_to_json(rspobj).encode()
class _Server(_ServerClientCommon):
    """Listening side of the test connection.

    Accepts a single client and wraps that connection in an RPCEndpoint.
    """

    def __init__(self, keepalive_interval: float, keepalive_timeout: float,
                 debug_print: bool) -> None:
        super().__init__(keepalive_interval=keepalive_interval,
                         keepalive_timeout=keepalive_timeout,
                         debug_print=debug_print)
        # Stays None until start() has created the listening socket.
        self.listener: Optional[asyncio.base_events.Server] = None

    async def start(self) -> None:
        """Create the listening socket. Call this before run()."""
        assert self.listener is None
        self.listener = await asyncio.start_server(self._handle_client,
                                                   host=ADDR,
                                                   port=PORT)

    async def run(self) -> None:
        """Serve until the listener gets closed."""
        listener = self.listener
        assert listener is not None
        assert self._endpoint is None
        async with listener:
            try:
                await listener.serve_forever()
            except asyncio.CancelledError:
                # Expected during teardown, where the test harness closes
                # the listener out from under serve_forever().
                pass

    async def _handle_client(self, reader: asyncio.StreamReader,
                             writer: asyncio.StreamWriter) -> None:
        """Wrap a newly accepted connection in an endpoint and run it."""
        # We only ever expect a single client connection.
        assert self._endpoint is None

        # Note to self: passing ourself as a handler creates a dependency
        # loop; in production code we'd probably want to store this as a
        # weak ref or whatnot to keep teardown deterministic.
        endpoint = RPCEndpoint(self._handle_raw_message,
                               reader,
                               writer,
                               keepalive_interval=self._keepalive_interval,
                               keepalive_timeout=self._keepalive_timeout,
                               debug_print=self._debug_print,
                               label='test_rpc_server')
        self._endpoint = endpoint
        await endpoint.run()
class _Client(_ServerClientCommon):
    """Connecting side of the test connection."""

    def __init__(self, keepalive_interval: float, keepalive_timeout: float,
                 debug_print: bool) -> None:
        super().__init__(keepalive_interval=keepalive_interval,
                         keepalive_timeout=keepalive_timeout,
                         debug_print=debug_print)

    async def run(self) -> None:
        """Connect to the server, create our endpoint, and run it."""
        streams = await asyncio.open_connection(host=ADDR, port=PORT)
        reader, writer = streams

        # Note to self: passing ourself as a handler creates a dependency
        # loop; in production code we'd probably want to store this as a
        # weak ref or whatnot to keep teardown deterministic.
        endpoint = RPCEndpoint(self._handle_raw_message,
                               reader,
                               writer,
                               keepalive_interval=self._keepalive_interval,
                               keepalive_timeout=self._keepalive_timeout,
                               debug_print=self._debug_print,
                               label='test_rpc_client')
        self._endpoint = endpoint
        await endpoint.run()
class _Tester:
    """Runs a test call against a connected server/client endpoint pair.

    Construction only creates the _Server and _Client objects; run() brings
    up the listener and the connection, awaits the supplied test call once
    both endpoints exist, tears everything down, and finally verifies that
    the endpoints are freed as soon as their last references are removed.
    """

    def __init__(
            self,
            keepalive_interval: float = RPCEndpoint.DEFAULT_KEEPALIVE_INTERVAL,
            keepalive_timeout: float = RPCEndpoint.DEFAULT_KEEPALIVE_TIMEOUT,
            server_debug_print: bool = True,
            client_debug_print: bool = True) -> None:
        self.client = _Client(keepalive_interval=keepalive_interval,
                              keepalive_timeout=keepalive_timeout,
                              debug_print=client_debug_print)
        self.server = _Server(keepalive_interval=keepalive_interval,
                              keepalive_timeout=keepalive_timeout,
                              debug_print=server_debug_print)

    # noinspection PyProtectedMember
    def run(self, testcall: Awaitable[None]) -> None:
        """Run our test.

        'testcall' is awaited inside a fresh event loop once the server
        and client endpoints are both available.
        """
        asyncio.run(self._run(testcall), debug=True)

        # Make sure the endpoints go down immediately when we remove our
        # only refs to them.
        server_endpoint_ref = weakref.ref(self.server.endpoint)
        client_endpoint_ref = weakref.ref(self.client.endpoint)
        del self.client._endpoint
        del self.server._endpoint
        assert server_endpoint_ref() is None, 'Server did not go down cleanly.'
        assert client_endpoint_ref() is None, 'Client did not go down cleanly.'

    async def _run(self, testcall: Awaitable[None]) -> None:
        """Bring up the listener, then run server, client and test together."""

        # Give server a chance to spin up before kicking off client.
        await self.server.start()

        # Now run our server, our client, and our tests simultaneously.
        await asyncio.gather(
            self.server.run(),
            self.client.run(),
            self._run_test(testcall),
        )

    async def _run_test(self, testcall: Awaitable[None]) -> None:
        """Set up before and tear down after a test call."""
        assert self.server.listener is not None

        # Wait until the connection is established on BOTH sides. Test
        # calls use the client endpoint as well as the server one, and the
        # two are created by independent tasks, so seeing only the server
        # endpoint does not guarantee that the client one exists yet.
        while not (self.server.has_endpoint() and self.client.has_endpoint()):
            await asyncio.sleep(0.01)

        print('test_rpc test call starting...')

        # Do the thing.
        await testcall

        print('test_rpc test call completed; tearing down...')

        # Close both our listener socket and our established endpoint;
        # this should break us out of our loop.
        self.server.endpoint.close()
        await self.server.endpoint.wait_closed()

        self.server.listener.close()
        await self.server.listener.wait_closed()
def test_keepalive_fail() -> None:
    """Test that a connection goes down when keepalives stop arriving."""
    interval = 0.05
    timeout = 0.25
    tester = _Tester(keepalive_interval=interval, keepalive_timeout=timeout)

    async def _check_endpoint_closes() -> None:
        # Tell our client to not send keepalives.
        tester.client.endpoint.test_suppress_keepalives = True

        # Once we are a bit past the keepalive timeout, the server side
        # should have noticed the silence and started going down.
        await asyncio.sleep(timeout * 1.25)
        assert tester.server.endpoint.is_closing()

    tester.run(_check_endpoint_closes())
def test_keepalive_success() -> None:
    """Test that a connection stays up while keepalives keep flowing."""
    interval = 0.05
    timeout = 0.25
    tester = _Tester(keepalive_interval=interval, keepalive_timeout=timeout)

    async def _check_endpoint_stays_up() -> None:
        # Wait a bit longer than the keepalive timeout; with keepalives
        # left enabled the server endpoint should NOT be going down.
        await asyncio.sleep(timeout * 1.25)
        server_closing = tester.server.endpoint.is_closing()
        assert not server_closing

    tester.run(_check_endpoint_stays_up())
def test_simple_messages() -> None:
    """Test basic messages and responses."""
    tester = _Tester()

    async def _exchange() -> None:
        # Each request kind is sent in both directions (server first, then
        # client) and must come back as its matching response kind.
        expectations = (
            (_MessageType.TEST1, _MessageType.RESPONSE1),
            (_MessageType.TEST2, _MessageType.RESPONSE2),
        )
        for request_type, response_type in expectations:
            for sender in (tester.server, tester.client):
                resp = await sender.send_message(_Message(request_type))
                assert resp.messagetype is response_type

    tester.run(_exchange())
def test_simultaneous_messages() -> None:
    """Test that messages in flight at the same time do not serialize."""
    tester = _Tester()

    async def _flood() -> None:
        # Three slow requests from each side, interleaved client/server.
        senders = (tester.client, tester.server) * 3

        # Fire everything off at once and time the whole batch.
        started = time.monotonic()
        results = await asyncio.gather(
            *(sender.send_message(_Message(_MessageType.TEST_SLOW))
              for sender in senders))
        elapsed = time.monotonic() - started

        # This should all go through in the same time that 1 goes through
        # in.
        assert elapsed < 1.1 * SLOW_WAIT

        # Make sure we got all correct responses.
        for result in results:
            assert result.messagetype is _MessageType.RESPONSE_SLOW

        # They should all be uniquely created message objects.
        assert len({id(result) for result in results}) == len(results)

    tester.run(_flood())
def test_message_timeout() -> None:
    """Test sends timing out."""
    tester = _Tester()

    async def _check_timeout() -> None:
        slow_request = _MessageType.TEST_SLOW

        # Without an explicit timeout the slow message should simply
        # return after a short wait.
        resp = await tester.server.send_message(_Message(slow_request))
        assert resp.messagetype is _MessageType.RESPONSE_SLOW

        # With a timeout shorter than the handler's wait, the same send
        # should time out.
        with pytest.raises(CommunicationError):
            await tester.server.send_message(_Message(slow_request),
                                             timeout=0.5)

    tester.run(_check_timeout())
def test_server_interrupt() -> None:
    """Test server dying during message send."""
    tester = _Tester()

    async def _do_it() -> None:

        async def _kill_connection() -> None:
            # Give the slow send below time to get in flight, then close
            # the server's endpoint out from under it.
            await asyncio.sleep(0.2)
            tester.server.endpoint.close()

        # The event loop only holds weak references to tasks, so keep our
        # own reference; otherwise the task could get garbage collected
        # before it ever closes the connection.
        kill_task = asyncio.create_task(_kill_connection())

        with pytest.raises(CommunicationError):
            await tester.server.send_message(_Message(_MessageType.TEST_SLOW))

        # Awaiting the task surfaces any error raised inside it instead of
        # leaving it as an unretrieved task exception.
        await kill_task

    tester.run(_do_it())
def test_client_interrupt() -> None:
    """Test client dying during message send."""
    tester = _Tester()

    async def _do_it() -> None:

        async def _kill_connection() -> None:
            # Give the slow send below time to get in flight, then close
            # the client's endpoint while the server awaits its response.
            await asyncio.sleep(0.2)
            tester.client.endpoint.close()

        # The event loop only holds weak references to tasks, so keep our
        # own reference; otherwise the task could get garbage collected
        # before it ever closes the connection.
        kill_task = asyncio.create_task(_kill_connection())

        with pytest.raises(CommunicationError):
            await tester.server.send_message(_Message(_MessageType.TEST_SLOW))

        # Awaiting the task surfaces any error raised inside it instead of
        # leaving it as an unretrieved task exception.
        await kill_task

    tester.run(_do_it())