diff --git a/tools/efro/message.py b/tools/efro/message.py index 3b59f78c..57970b12 100644 --- a/tools/efro/message.py +++ b/tools/efro/message.py @@ -261,39 +261,40 @@ class MessageProtocol: def _get_module_header(self, part: str) -> str: """Return common parts of generated modules.""" imports: Dict[str, List[str]] = {} - for msgtype in self.message_ids_by_type: + for msgtype in list(self.message_ids_by_type) + [Message]: imports.setdefault(msgtype.__module__, []).append(msgtype.__name__) - importlines = '' + for rsp_tp in list(self.response_ids_by_type) + [Response]: + # Skip these as they don't actually show up in code. + if rsp_tp is EmptyResponse or rsp_tp is ErrorResponse: + continue + imports.setdefault(rsp_tp.__module__, []).append(rsp_tp.__name__) + importlines2 = '' for module, names in sorted(imports.items()): jnames = ', '.join(names) line = f'from {module} import {jnames}' if len(line) > 79: # Recreate in a wrapping-friendly form. line = f'from {module} import ({jnames})' - importlines += f'{line}\n' + importlines2 += f' {line}\n' if part == 'sender': - importlines = ( - f'from efro.message import MessageSender\n{importlines}') - tpimports = 'from efro.message import Message, Response' + importlines1 = 'from efro.message import MessageSender' else: - importlines = ( - f'from efro.message import MessageReceiver\n{importlines}') - tpimports = 'from efro.message import Message, Response' + importlines1 = 'from efro.message import MessageReceiver' out = ('# Released under the MIT License. See LICENSE for details.\n' f'#\n' - f'"""Auto-generated {part} module."""\n' + f'"""Auto-generated {part} module. Do not edit by hand."""\n' f'\n' f'from __future__ import annotations\n' f'\n' f'from typing import TYPE_CHECKING, overload\n' f'\n' - f'{importlines}' + f'{importlines1}\n' f'\n' f'if TYPE_CHECKING:\n' - f' from typing import Union\n' - f' {tpimports}\n' + f' from typing import Union, Any, Optional\n' + f'{importlines2}' f'\n' f'\n') return out diff --git a/tools/efrotools/code.py b/tools/efrotools/code.py index 8dc4a28e..9fe29f6c 100644 --- a/tools/efrotools/code.py +++ b/tools/efrotools/code.py @@ -189,11 +189,15 @@ def format_yapf(projroot: Path, full: bool) -> None: flush=True) -def format_yapf_text(projroot: Path, code: str) -> str: - """Run yapf formatting on the provided code.""" - del projroot # Unused. - print('WOULD DO YAPF') - return code +def format_yapf_str(projroot: Path, code: str) -> str: + """Run yapf formatting on the provided inline code.""" + from efrotools import PYVER + out = subprocess.run([f'python{PYVER}', '-m', 'yapf'], + capture_output=True, + check=True, + input=code.encode(), + cwd=projroot) + return out.stdout.decode() def _should_include_script(fnamefull: str) -> bool: