1
0
Fork 0
mirror of https://github.com/NixOS/nixpkgs.git synced 2025-06-21 17:01:10 +03:00

nixos/test-driver: apply ruff check suggestions

This commit is contained in:
Nick Cao 2024-11-22 09:16:03 -05:00
parent 9069a281a7
commit b25360a7e5
No known key found for this signature in database
3 changed files with 48 additions and 45 deletions

View file

@ -3,9 +3,10 @@ import re
import signal
import tempfile
import threading
from contextlib import contextmanager
from collections.abc import Iterator
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Union
from typing import Any, Callable, Optional, Union
from colorama import Fore, Style
@ -44,17 +45,17 @@ class Driver:
and runs the tests"""
tests: str
vlans: List[VLan]
machines: List[Machine]
polling_conditions: List[PollingCondition]
vlans: list[VLan]
machines: list[Machine]
polling_conditions: list[PollingCondition]
global_timeout: int
race_timer: threading.Timer
logger: AbstractLogger
def __init__(
self,
start_scripts: List[str],
vlans: List[int],
start_scripts: list[str],
vlans: list[int],
tests: str,
out_dir: Path,
logger: AbstractLogger,
@ -73,7 +74,7 @@ class Driver:
vlans = list(set(vlans))
self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in vlans]
def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
def cmd(scripts: list[str]) -> Iterator[NixStartScript]:
for s in scripts:
yield NixStartScript(s)
@ -119,7 +120,7 @@ class Driver:
self.logger.error(f'Test "{name}" failed with error: "{e}"')
raise e
def test_symbols(self) -> Dict[str, Any]:
def test_symbols(self) -> dict[str, Any]:
@contextmanager
def subtest(name: str) -> Iterator[None]:
return self.subtest(name)
@ -277,7 +278,7 @@ class Driver:
*,
seconds_interval: float = 2.0,
description: Optional[str] = None,
) -> Union[Callable[[Callable], ContextManager], ContextManager]:
) -> Union[Callable[[Callable], AbstractContextManager], AbstractContextManager]:
driver = self
class Poll:

View file

@ -5,10 +5,11 @@ import sys
import time
import unicodedata
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import ExitStack, contextmanager
from pathlib import Path
from queue import Empty, Queue
from typing import Any, Dict, Iterator, List
from typing import Any
from xml.sax.saxutils import XMLGenerator
from xml.sax.xmlreader import AttributesImpl
@ -18,17 +19,17 @@ from junit_xml import TestCase, TestSuite
class AbstractLogger(ABC):
@abstractmethod
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
pass
@abstractmethod
@contextmanager
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
pass
@abstractmethod
@contextmanager
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
pass
@abstractmethod
@ -68,11 +69,11 @@ class JunitXMLLogger(AbstractLogger):
self._print_serial_logs = True
atexit.register(self.close)
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
self.tests[self.currentSubtest].stdout += message + os.linesep
@contextmanager
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
old_test = self.currentSubtest
self.tests.setdefault(name, self.TestCaseState())
self.currentSubtest = name
@ -82,7 +83,7 @@ class JunitXMLLogger(AbstractLogger):
self.currentSubtest = old_test
@contextmanager
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
self.log(message)
yield
@ -123,25 +124,25 @@ class JunitXMLLogger(AbstractLogger):
class CompositeLogger(AbstractLogger):
def __init__(self, logger_list: List[AbstractLogger]) -> None:
def __init__(self, logger_list: list[AbstractLogger]) -> None:
self.logger_list = logger_list
def add_logger(self, logger: AbstractLogger) -> None:
self.logger_list.append(logger)
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
for logger in self.logger_list:
logger.log(message, attributes)
@contextmanager
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
with ExitStack() as stack:
for logger in self.logger_list:
stack.enter_context(logger.subtest(name, attributes))
yield
@contextmanager
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
with ExitStack() as stack:
for logger in self.logger_list:
stack.enter_context(logger.nested(message, attributes))
@ -173,7 +174,7 @@ class TerminalLogger(AbstractLogger):
def __init__(self) -> None:
self._print_serial_logs = True
def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
if "machine" in attributes:
return f"{attributes['machine']}: {message}"
return message
@ -182,16 +183,16 @@ class TerminalLogger(AbstractLogger):
def _eprint(*args: object, **kwargs: Any) -> None:
print(*args, file=sys.stderr, **kwargs)
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
self._eprint(self.maybe_prefix(message, attributes))
@contextmanager
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
with self.nested("subtest: " + name, attributes):
yield
@contextmanager
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
self._eprint(
self.maybe_prefix(
Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes
@ -241,12 +242,12 @@ class XMLLogger(AbstractLogger):
def sanitise(self, message: str) -> str:
return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")
def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
if "machine" in attributes:
return f"{attributes['machine']}: {message}"
return message
def log_line(self, message: str, attributes: Dict[str, str]) -> None:
def log_line(self, message: str, attributes: dict[str, str]) -> None:
self.xml.startElement("line", attrs=AttributesImpl(attributes))
self.xml.characters(message)
self.xml.endElement("line")
@ -260,7 +261,7 @@ class XMLLogger(AbstractLogger):
def error(self, *args, **kwargs) -> None: # type: ignore
self.log(*args, **kwargs)
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
self.drain_log_queue()
self.log_line(message, attributes)
@ -273,7 +274,7 @@ class XMLLogger(AbstractLogger):
self.enqueue({"msg": message, "machine": machine, "type": "serial"})
def enqueue(self, item: Dict[str, str]) -> None:
def enqueue(self, item: dict[str, str]) -> None:
self.queue.put(item)
def drain_log_queue(self) -> None:
@ -287,12 +288,12 @@ class XMLLogger(AbstractLogger):
pass
@contextmanager
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
with self.nested("subtest: " + name, attributes):
yield
@contextmanager
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
self.xml.startElement("nest", attrs=AttributesImpl({}))
self.xml.startElement("head", attrs=AttributesImpl(attributes))
self.xml.characters(message)

View file

@ -12,10 +12,11 @@ import sys
import tempfile
import threading
import time
from collections.abc import Iterable
from contextlib import _GeneratorContextManager, nullcontext
from pathlib import Path
from queue import Queue
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, Optional
from test_driver.logger import AbstractLogger
@ -91,7 +92,7 @@ def make_command(args: list) -> str:
def _perform_ocr_on_screenshot(
screenshot_path: str, model_ids: Iterable[int]
) -> List[str]:
) -> list[str]:
if shutil.which("tesseract") is None:
raise Exception("OCR requested but enableOCR is false")
@ -260,7 +261,7 @@ class Machine:
# Store last serial console lines for use
# of wait_for_console_text
last_lines: Queue = Queue()
callbacks: List[Callable]
callbacks: list[Callable]
def __repr__(self) -> str:
return f"<Machine '{self.name}'>"
@ -273,7 +274,7 @@ class Machine:
logger: AbstractLogger,
name: str = "machine",
keep_vm_state: bool = False,
callbacks: Optional[List[Callable]] = None,
callbacks: Optional[list[Callable]] = None,
) -> None:
self.out_dir = out_dir
self.tmp_dir = tmp_dir
@ -314,7 +315,7 @@ class Machine:
def log_serial(self, msg: str) -> None:
self.logger.log_serial(msg, self.name)
def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager:
def nested(self, msg: str, attrs: dict[str, str] = {}) -> _GeneratorContextManager:
my_attrs = {"machine": self.name}
my_attrs.update(attrs)
return self.logger.nested(msg, my_attrs)
@ -373,7 +374,7 @@ class Machine:
):
retry(check_active, timeout)
def get_unit_info(self, unit: str, user: Optional[str] = None) -> Dict[str, str]:
def get_unit_info(self, unit: str, user: Optional[str] = None) -> dict[str, str]:
status, lines = self.systemctl(f'--no-pager show "{unit}"', user)
if status != 0:
raise Exception(
@ -384,7 +385,7 @@ class Machine:
line_pattern = re.compile(r"^([^=]+)=(.*)$")
def tuple_from_line(line: str) -> Tuple[str, str]:
def tuple_from_line(line: str) -> tuple[str, str]:
match = line_pattern.match(line)
assert match is not None
return match[1], match[2]
@ -424,7 +425,7 @@ class Machine:
assert match[1] == property, invalid_output_message
return match[2]
def systemctl(self, q: str, user: Optional[str] = None) -> Tuple[int, str]:
def systemctl(self, q: str, user: Optional[str] = None) -> tuple[int, str]:
"""
Runs `systemctl` commands with optional support for
`systemctl --user`
@ -481,7 +482,7 @@ class Machine:
check_return: bool = True,
check_output: bool = True,
timeout: Optional[int] = 900,
) -> Tuple[int, str]:
) -> tuple[int, str]:
"""
Execute a shell command, returning a list `(status, stdout)`.
@ -798,10 +799,10 @@ class Machine:
with self.nested(f"waiting for TCP port {port} on {addr} to be closed"):
retry(port_is_closed, timeout)
def start_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]:
def start_job(self, jobname: str, user: Optional[str] = None) -> tuple[int, str]:
return self.systemctl(f"start {jobname}", user)
def stop_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]:
def stop_job(self, jobname: str, user: Optional[str] = None) -> tuple[int, str]:
return self.systemctl(f"stop {jobname}", user)
def wait_for_job(self, jobname: str) -> None:
@ -942,13 +943,13 @@ class Machine:
"""Debugging: Dump the contents of the TTY<n>"""
self.execute(f"fold -w 80 /dev/vcs{tty} | systemd-cat")
def _get_screen_text_variants(self, model_ids: Iterable[int]) -> List[str]:
def _get_screen_text_variants(self, model_ids: Iterable[int]) -> list[str]:
with tempfile.TemporaryDirectory() as tmpdir:
screenshot_path = os.path.join(tmpdir, "ppm")
self.send_monitor_command(f"screendump {screenshot_path}")
return _perform_ocr_on_screenshot(screenshot_path, model_ids)
def get_screen_text_variants(self) -> List[str]:
def get_screen_text_variants(self) -> list[str]:
"""
Return a list of different interpretations of what is currently
visible on the machine's screen using optical character
@ -1168,7 +1169,7 @@ class Machine:
with self.nested("waiting for the X11 server"):
retry(check_x, timeout)
def get_window_names(self) -> List[str]:
def get_window_names(self) -> list[str]:
return self.succeed(
r"xwininfo -root -tree | sed 's/.*0x[0-9a-f]* \"\([^\"]*\)\".*/\1/; t; d'"
).splitlines()