mirror of
https://github.com/RYDE-WORK/ballistica.git
synced 2026-01-19 21:37:57 +08:00
377 lines
12 KiB
Python
377 lines
12 KiB
Python
# Released under the MIT License. See LICENSE for details.
|
|
#
|
|
"""Testing rpc functionality."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
import random
|
|
import asyncio
|
|
import weakref
|
|
from enum import unique, Enum
|
|
from typing import TYPE_CHECKING
|
|
from dataclasses import dataclass
|
|
|
|
import pytest
|
|
|
|
from efro.rpc import RPCEndpoint
|
|
from efro.error import CommunicationError
|
|
from efro.dataclassio import ioprepped, dataclass_from_json, dataclass_to_json
|
|
|
|
if TYPE_CHECKING:
|
|
from typing import Optional, Awaitable, Callable
|
|
|
|
ADDR = '127.0.0.1'

# The port is randomized a bit so parallel test runs are less likely to
# collide with each other. Ideally we would let the OS pick the port and
# hand the resulting value to the client.
PORT = random.randint(8000, 15999)

# How long (in seconds) the TEST_SLOW handler sleeps before responding.
SLOW_WAIT = 1.0
|
|
|
|
|
|
@unique
class _MessageType(Enum):
    """Kinds of high-level message exchanged by the test endpoints.

    Every request type has a matching response type that the handler
    sends back.
    """
    # Plain request/response pairs.
    TEST1 = 't1'
    RESPONSE1 = 'r1'
    # NOTE(review): unlike the other request values this one has no 't'
    # prefix; harmless since both ends share this enum.
    TEST2 = '2'
    RESPONSE2 = 'r2'
    # Request whose handler sleeps for SLOW_WAIT before answering.
    TEST_SLOW = 'ts'
    RESPONSE_SLOW = 'rs'
|
|
|
|
|
|
@ioprepped
@dataclass
class _Message:
    """A high-level test message; converted to/from JSON for the wire."""

    # Which kind of message this is (request or response).
    messagetype: _MessageType
|
|
|
|
|
|
class _ServerClientCommon:
    """Functionality shared by the test server and the test client.

    Stores the keepalive/debug settings used to build an RPCEndpoint and
    translates between raw bytes on the wire and _Message objects.
    """

    def __init__(self, keepalive_interval: float, keepalive_timeout: float,
                 debug_print: bool) -> None:
        # Assigned by the subclass once a connection has been made.
        self._endpoint: Optional[RPCEndpoint] = None
        self._keepalive_interval = keepalive_interval
        self._keepalive_timeout = keepalive_timeout
        self._debug_print = debug_print

    def has_endpoint(self) -> bool:
        """Is our endpoint up yet?"""
        return self._endpoint is not None

    @property
    def endpoint(self) -> RPCEndpoint:
        """Our endpoint.

        Raises RuntimeError if it has not been created yet.
        """
        if self._endpoint is None:
            raise RuntimeError('Expected endpoint to exist.')
        return self._endpoint

    async def send_message(self,
                           message: _Message,
                           timeout: Optional[float] = None) -> _Message:
        """Send a high-level message and return the decoded response.

        `timeout` is passed straight through to RPCEndpoint.send_message().
        Raises RuntimeError if our endpoint does not exist yet.
        """
        # Go through the property rather than a bare assert so a missing
        # endpoint yields a clear error even when asserts are disabled (-O).
        endpoint = self.endpoint
        response = await endpoint.send_message(
            dataclass_to_json(message).encode(), timeout=timeout)
        return dataclass_from_json(_Message, response.decode())

    async def handle_message(self, msg: _Message) -> _Message:
        """Handle a high-level message and return the response to send.

        Raises RuntimeError for message types we do not handle.
        """
        if msg.messagetype is _MessageType.TEST1:
            return _Message(_MessageType.RESPONSE1)

        if msg.messagetype is _MessageType.TEST2:
            return _Message(_MessageType.RESPONSE2)

        if msg.messagetype is _MessageType.TEST_SLOW:
            # Deliberately slow so tests can exercise concurrency/timeouts.
            await asyncio.sleep(SLOW_WAIT)
            return _Message(_MessageType.RESPONSE_SLOW)

        raise RuntimeError(f'Got unexpected message type: {msg.messagetype}')

    async def _handle_raw_message(self, message: bytes) -> bytes:
        """Raw handler given to RPCEndpoint: decode, dispatch, re-encode."""
        msgobj = dataclass_from_json(_Message, message.decode())
        rspobj = await self.handle_message(msgobj)
        return dataclass_to_json(rspobj).encode()
|
|
|
|
|
|
class _Server(_ServerClientCommon):
    """Test server: listens on ADDR/PORT and wraps one client connection."""

    def __init__(self, keepalive_interval: float, keepalive_timeout: float,
                 debug_print: bool) -> None:
        super().__init__(keepalive_interval=keepalive_interval,
                         keepalive_timeout=keepalive_timeout,
                         debug_print=debug_print)
        # Use the public asyncio.Server name instead of reaching into the
        # private asyncio.base_events module.
        self.listener: Optional[asyncio.Server] = None

    async def start(self) -> None:
        """Start serving. Call this before run()."""
        assert self.listener is None
        self.listener = await asyncio.start_server(self._handle_client, ADDR,
                                                   PORT)

    async def run(self) -> None:
        """Serve connections until our listener stops serving.

        Call start() first. A CancelledError coming out of serve_forever()
        is treated as a normal, quiet exit.
        """
        assert self.listener is not None
        assert self._endpoint is None
        async with self.listener:
            try:
                await self.listener.serve_forever()
            except asyncio.CancelledError:
                # Expected during teardown (listener closed / task
                # cancelled); nothing to do.
                pass

    async def _handle_client(self, reader: asyncio.StreamReader,
                             writer: asyncio.StreamWriter) -> None:
        """Connection callback: wrap the new connection in our endpoint.

        Only a single client connection is expected per server.
        """
        assert self._endpoint is None

        # Note to self: passing ourself as a handler creates a dependency
        # loop; in production code we'd probably want to store this as a
        # weak ref or whatnot to keep teardown deterministic.
        self._endpoint = RPCEndpoint(
            self._handle_raw_message,
            reader,
            writer,
            keepalive_interval=self._keepalive_interval,
            keepalive_timeout=self._keepalive_timeout,
            debug_print=self._debug_print,
            label='test_rpc_server')

        await self._endpoint.run()
|
|
|
|
|
|
class _Client(_ServerClientCommon):
    """Test client: connects to ADDR/PORT and wraps the connection.

    The constructor is inherited unchanged from _ServerClientCommon.
    """

    async def run(self) -> None:
        """Connect to the server, then create and run our endpoint."""
        reader, writer = await asyncio.open_connection(ADDR, PORT)
        # Note to self: passing ourself as a handler creates a dependency
        # loop; in production code we'd probably want to store this as a
        # weak ref or whatnot to keep teardown deterministic.
        self._endpoint = RPCEndpoint(
            self._handle_raw_message,
            reader,
            writer,
            keepalive_interval=self._keepalive_interval,
            keepalive_timeout=self._keepalive_timeout,
            debug_print=self._debug_print,
            label='test_rpc_client')
        await self._endpoint.run()
|
|
|
|
|
|
class _Tester:
    """Runs a server, a client, and one test coroutine in a single loop."""

    def __init__(
            self,
            keepalive_interval: float = RPCEndpoint.DEFAULT_KEEPALIVE_INTERVAL,
            keepalive_timeout: float = RPCEndpoint.DEFAULT_KEEPALIVE_TIMEOUT,
            server_debug_print: bool = True,
            client_debug_print: bool = True) -> None:
        self.client = _Client(keepalive_interval=keepalive_interval,
                              keepalive_timeout=keepalive_timeout,
                              debug_print=client_debug_print)
        self.server = _Server(keepalive_interval=keepalive_interval,
                              keepalive_timeout=keepalive_timeout,
                              debug_print=server_debug_print)

    # noinspection PyProtectedMember
    def run(self, testcall: Awaitable[None]) -> None:
        """Run our test.

        After the event loop finishes, also checks that both endpoints are
        freed as soon as our references to them are dropped.
        """

        asyncio.run(self._run(testcall), debug=True)

        # Make sure the endpoints go down immediately when we remove our
        # only refs to them.
        server_endpoint_ref = weakref.ref(self.server.endpoint)
        client_endpoint_ref = weakref.ref(self.client.endpoint)
        # Clear (rather than delete) the attributes so the objects stay in
        # a consistent state and has_endpoint() keeps working.
        self.client._endpoint = None
        self.server._endpoint = None
        assert server_endpoint_ref() is None, 'Server did not go down cleanly.'
        assert client_endpoint_ref() is None, 'Client did not go down cleanly.'

    async def _run(self, testcall: Awaitable[None]) -> None:
        """Start the server, then run server, client, and test together."""

        # Give server a chance to spin up before kicking off client.
        await self.server.start()

        # Now run our server, our client, and our tests simultaneously.
        await asyncio.gather(
            self.server.run(),
            self.client.run(),
            self._run_test(testcall),
        )

    async def _run_test(self, testcall: Awaitable[None]) -> None:
        """Set up before and tear down after a test call."""
        assert self.server.listener is not None

        # Wait until the connection is established on BOTH ends. Checking
        # only the server could let a test touch the client endpoint
        # before the client has finished creating it.
        while not (self.server.has_endpoint()
                   and self.client.has_endpoint()):
            await asyncio.sleep(0.01)

        print('test_rpc test call starting...')

        # Do the thing.
        await testcall

        print('test_rpc test call completed; tearing down...')

        # Close both our listener socket and our established endpoint;
        # this should break us out of our loop.
        self.server.endpoint.close()
        await self.server.endpoint.wait_closed()

        self.server.listener.close()
        await self.server.listener.wait_closed()
|
|
|
|
|
|
def test_keepalive_fail() -> None:
    """Test that a peer going silent trips the keepalive timeout."""
    interval = 0.05
    timeout = 0.25
    tester = _Tester(keepalive_interval=interval, keepalive_timeout=timeout)

    async def _check_server_drops_connection() -> None:
        # Stop the client from sending keepalives altogether.
        tester.client.endpoint.test_suppress_keepalives = True

        # A little past the keepalive timeout, the server endpoint should
        # already be on its way down.
        await asyncio.sleep(timeout * 1.25)
        assert tester.server.endpoint.is_closing()

    tester.run(_check_server_drops_connection())
|
|
|
|
|
|
def test_keepalive_success() -> None:
    """Test that regular keepalives keep the connection open."""
    interval = 0.05
    timeout = 0.25
    tester = _Tester(keepalive_interval=interval, keepalive_timeout=timeout)

    async def _check_connection_survives() -> None:
        # Wait a little past the keepalive timeout; with keepalives flowing
        # normally the server endpoint must still be up.
        await asyncio.sleep(timeout * 1.25)
        assert not tester.server.endpoint.is_closing()

    tester.run(_check_connection_survives())
|
|
|
|
|
|
def test_simple_messages() -> None:
    """Test plain request/response round trips in both directions."""
    tester = _Tester()

    async def _do_it() -> None:
        # Each request type is sent first by the server and then by the
        # client; every send must come back with the matching response.
        expectations = (
            (_MessageType.TEST1, _MessageType.RESPONSE1),
            (_MessageType.TEST2, _MessageType.RESPONSE2),
        )
        for request_type, response_type in expectations:
            for sender in (tester.server, tester.client):
                resp = await sender.send_message(_Message(request_type))
                assert resp.messagetype is response_type

    tester.run(_do_it())
|
|
|
|
|
|
def test_simultaneous_messages() -> None:
    """Test that slow messages sent both ways at once run concurrently."""
    tester = _Tester()

    async def _do_it() -> None:

        # Send a bunch of slow messages both ways at once (alternating
        # client/server) and make sure they all come through
        # simultaneously-ish.
        senders = (tester.client, tester.server) * 3
        starttime = time.monotonic()
        results = await asyncio.gather(
            *(sender.send_message(_Message(_MessageType.TEST_SLOW))
              for sender in senders))

        # This should all go through in the same time that 1 goes through in.
        assert (time.monotonic() - starttime) < 1.1 * SLOW_WAIT

        # Make sure we got all correct responses.
        assert all(r.messagetype is _MessageType.RESPONSE_SLOW
                   for r in results)

        # They should all be uniquely created message objects.
        assert len({id(r) for r in results}) == len(results)

    tester.run(_do_it())
|
|
|
|
|
|
def test_message_timeout() -> None:
    """Test sends timing out."""
    tester = _Tester()

    async def _do_it() -> None:

        # Without a timeout, the slow message returns after SLOW_WAIT.
        resp = await tester.server.send_message(
            _Message(_MessageType.TEST_SLOW))
        assert resp.messagetype is _MessageType.RESPONSE_SLOW

        # With a timeout shorter than the handler's sleep, the send must
        # fail. (Derived from SLOW_WAIT so the two stay in sync.)
        with pytest.raises(CommunicationError):
            await tester.server.send_message(
                _Message(_MessageType.TEST_SLOW),
                timeout=SLOW_WAIT * 0.5,
            )

    tester.run(_do_it())
|
|
|
|
|
|
def test_server_interrupt() -> None:
    """Test server dying during message send."""
    tester = _Tester()

    async def _do_it() -> None:

        async def _kill_connection() -> None:
            await asyncio.sleep(0.2)
            tester.server.endpoint.close()

        # Keep a reference to the task; the event loop only holds weak
        # references to tasks, so an unreferenced one may be collected
        # before it runs.
        kill_task = asyncio.create_task(_kill_connection())
        with pytest.raises(CommunicationError):
            await tester.server.send_message(_Message(_MessageType.TEST_SLOW))

        # Make sure the kill actually ran (and surface any error from it).
        await kill_task

    tester.run(_do_it())
|
|
|
|
|
|
def test_client_interrupt() -> None:
    """Test client dying during message send."""
    tester = _Tester()

    async def _do_it() -> None:

        async def _kill_connection() -> None:
            await asyncio.sleep(0.2)
            tester.client.endpoint.close()

        # Keep a reference to the task; the event loop only holds weak
        # references to tasks, so an unreferenced one may be collected
        # before it runs.
        kill_task = asyncio.create_task(_kill_connection())

        # The server is the sender here; it should see the client drop out
        # from under its in-flight message.
        with pytest.raises(CommunicationError):
            await tester.server.send_message(_Message(_MessageType.TEST_SLOW))

        # Make sure the kill actually ran (and surface any error from it).
        await kill_task

    tester.run(_do_it())
|