2024-04-23 13:44:40 -07:00

192 lines
5.2 KiB
Python

# Released under the MIT License. See LICENSE for details.
#
"""Misc util calls/etc.
Ideally the stuff in here should migrate to more descriptive module names.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, overload
if TYPE_CHECKING:
from typing import Sequence, Literal
from pathlib import Path
def explicit_bool(value: bool) -> bool:
    """Pass a bool through unchanged.

    Routing a constant through this call hides its literal value from
    static analysis, which can silence unreachable-code type warnings.
    """
    return value
def extract_flag(args: list[str], name: str) -> bool:
    """Report whether flag `name` appears in `args`, consuming it if so.

    The list is modified in place: a flag that is present exactly once is
    removed from it. A flag that appears more than once raises CleanError.
    """
    from efro.error import CleanError

    occurrences = args.count(name)
    if occurrences > 1:
        raise CleanError(f'Flag {name} passed multiple times.')

    present = occurrences == 1
    if present:
        args.remove(name)
    return present
@overload
def extract_arg(
    args: list[str], name: str, required: Literal[False] = False
) -> str | None: ...


@overload
def extract_arg(args: list[str], name: str, required: Literal[True]) -> str: ...


def extract_arg(
    args: list[str], name: str, required: bool = False
) -> str | None:
    """Pull the value following arg `name` out of `args`.

    Both the arg name and its value are removed from the list in place.
    Returns None when the arg is absent and not required.

    Raises CleanError if the arg is required but missing, is passed more
    than once, or is the final entry with no value after it.
    """
    from efro.error import CleanError

    occurrences = args.count(name)
    if occurrences > 1:
        raise CleanError(f'Arg {name} passed multiple times.')
    if occurrences == 0:
        if required:
            raise CleanError(f'Required argument {name} not passed.')
        return None

    name_pos = args.index(name)
    if name_pos + 1 >= len(args):
        raise CleanError(f'No value passed after {name} arg.')

    # Drop the name; the value then slides into the same slot, so popping
    # that index again both removes and returns it.
    args.pop(name_pos)
    return args.pop(name_pos)
def replace_section(
    text: str,
    begin_marker: str,
    end_marker: str,
    replace_text: str = '',
    keep_markers: bool = False,
    error_if_missing: bool = True,
) -> str:
    """Replace the text between two marker strings.

    By default the markers themselves are replaced too; pass
    keep_markers=True to retain them around the replacement text.

    Args:
        text: The text to operate on.
        begin_marker: String marking the start of the section. Must occur
            exactly once in `text`.
        end_marker: String marking the end of the section. Must occur
            exactly once in the text following `begin_marker`.
        replace_text: What to put in place of the section.
        keep_markers: If True, the begin and end markers are re-emitted
            around `replace_text`.
        error_if_missing: If True, a missing begin marker is an error; if
            False, `text` is returned unchanged in that case.

    Returns:
        The text with the section replaced.

    Raises:
        RuntimeError: If the begin marker is missing (and error_if_missing
            is set), if the begin marker occurs more than once, or if the
            end marker does not occur exactly once after the begin marker.
    """
    if begin_marker not in text:
        if error_if_missing:
            raise RuntimeError(f"Marker not found in text: '{begin_marker}'.")
        return text

    splits = text.split(begin_marker)
    if len(splits) != 2:
        raise RuntimeError(
            f"Expected one marker '{begin_marker}'"
            f'; found {text.count(begin_marker)}.'
        )
    before_begin, after_begin = splits

    # Only end markers appearing after the begin marker are considered, so
    # the error must report the count within that same region; counting
    # over the whole text could yield a contradictory
    # "expected one; found 1" message.
    splits = after_begin.split(end_marker)
    if len(splits) != 2:
        raise RuntimeError(
            f"Expected one marker '{end_marker}' after '{begin_marker}'"
            f'; found {after_begin.count(end_marker)}.'
        )
    _before_end, after_end = splits

    if keep_markers:
        replace_text = f'{begin_marker}{replace_text}{end_marker}'
    return f'{before_begin}{replace_text}{after_end}'
def readfile(path: str | Path) -> str:
    """Return the full contents of the utf-8 text file at `path`."""
    with open(path, mode='r', encoding='utf-8') as handle:
        contents = handle.read()
    return contents
def writefile(path: str | Path, txt: str) -> None:
    """Write `txt` to `path` as utf-8 text, replacing any existing file."""
    with open(path, mode='w', encoding='utf-8') as handle:
        handle.write(txt)
def replace_exact(
    opstr: str, old: str, new: str, count: int = 1, label: str | None = None
) -> str:
    """Replace `old` with `new`, insisting on an exact occurrence count.

    Handy for scripted edits of known input: if the input has drifted and
    `old` no longer appears exactly `count` times, a RuntimeError is raised
    instead of silently producing a different result. The optional `label`
    is included in the error message to identify the input.
    """
    occurrences = opstr.count(old)
    if occurrences == count:
        return opstr.replace(old, new)

    label_str = '' if label is None else f' in {label}'
    raise RuntimeError(
        f'Expected {count} string occurrence(s){label_str};'
        f' found {occurrences}. String: {repr(old)}'
    )
def get_files_hash(
    filenames: Sequence[str | Path],
    extrahash: str = '',
    int_only: bool = False,
    hashtype: Literal['md5', 'sha256'] = 'md5',
) -> str:
    """Hash the contents of a list of files into a single digest.

    File contents are fed to the hash in the order given, followed by the
    encoded `extrahash` string. The result is the hex digest, or, when
    `int_only` is set, the digest expressed as a decimal integer string.

    `filenames` must be an actual list; anything else raises RuntimeError.
    """
    import hashlib

    if not isinstance(filenames, list):
        raise RuntimeError(f'Expected a list; got a {type(filenames)}.')

    if TYPE_CHECKING:
        # Give Mypy a concrete hash object type to work with.
        hasher = hashlib.md5()
    else:
        hasher = getattr(hashlib, hashtype)()

    # Read in 1 MiB pieces rather than loading whole files at once.
    chunk_size = 2**20
    for fname in filenames:
        with open(fname, 'rb') as infile:
            for chunk in iter(lambda: infile.read(chunk_size), b''):
                hasher.update(chunk)
    hasher.update(extrahash.encode())

    if not int_only:
        return hasher.hexdigest()
    return str(int.from_bytes(hasher.digest(), byteorder='big'))
def get_string_hash(
    value: str,
    int_only: bool = False,
    hashtype: Literal['md5', 'sha256'] = 'md5',
) -> str:
    """Return a hash for the given string.

    The string is encoded and run through the requested hash. The result
    is the hex digest, or, when `int_only` is set, the digest expressed as
    a decimal integer string. A non-str `value` raises TypeError.
    """
    import hashlib

    if not isinstance(value, str):
        raise TypeError('Expected a str.')

    if TYPE_CHECKING:
        # Give Mypy a concrete hash object type to work with.
        hasher = hashlib.md5()
    else:
        hasher = getattr(hashlib, hashtype)()
    hasher.update(value.encode())

    if not int_only:
        return hasher.hexdigest()
    return str(int.from_bytes(hasher.digest(), byteorder='big'))