1
0
Fork 0
mirror of https://github.com/NixOS/nixpkgs.git synced 2025-06-21 17:01:10 +03:00

nixos/test-driver: apply ruff check suggestions

This commit is contained in:
Nick Cao 2024-11-22 09:16:03 -05:00
parent 9069a281a7
commit b25360a7e5
No known key found for this signature in database
3 changed files with 48 additions and 45 deletions

View file

@ -3,9 +3,10 @@ import re
import signal import signal
import tempfile import tempfile
import threading import threading
from contextlib import contextmanager from collections.abc import Iterator
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Union from typing import Any, Callable, Optional, Union
from colorama import Fore, Style from colorama import Fore, Style
@ -44,17 +45,17 @@ class Driver:
and runs the tests""" and runs the tests"""
tests: str tests: str
vlans: List[VLan] vlans: list[VLan]
machines: List[Machine] machines: list[Machine]
polling_conditions: List[PollingCondition] polling_conditions: list[PollingCondition]
global_timeout: int global_timeout: int
race_timer: threading.Timer race_timer: threading.Timer
logger: AbstractLogger logger: AbstractLogger
def __init__( def __init__(
self, self,
start_scripts: List[str], start_scripts: list[str],
vlans: List[int], vlans: list[int],
tests: str, tests: str,
out_dir: Path, out_dir: Path,
logger: AbstractLogger, logger: AbstractLogger,
@ -73,7 +74,7 @@ class Driver:
vlans = list(set(vlans)) vlans = list(set(vlans))
self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in vlans] self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in vlans]
def cmd(scripts: List[str]) -> Iterator[NixStartScript]: def cmd(scripts: list[str]) -> Iterator[NixStartScript]:
for s in scripts: for s in scripts:
yield NixStartScript(s) yield NixStartScript(s)
@ -119,7 +120,7 @@ class Driver:
self.logger.error(f'Test "{name}" failed with error: "{e}"') self.logger.error(f'Test "{name}" failed with error: "{e}"')
raise e raise e
def test_symbols(self) -> Dict[str, Any]: def test_symbols(self) -> dict[str, Any]:
@contextmanager @contextmanager
def subtest(name: str) -> Iterator[None]: def subtest(name: str) -> Iterator[None]:
return self.subtest(name) return self.subtest(name)
@ -277,7 +278,7 @@ class Driver:
*, *,
seconds_interval: float = 2.0, seconds_interval: float = 2.0,
description: Optional[str] = None, description: Optional[str] = None,
) -> Union[Callable[[Callable], ContextManager], ContextManager]: ) -> Union[Callable[[Callable], AbstractContextManager], AbstractContextManager]:
driver = self driver = self
class Poll: class Poll:

View file

@ -5,10 +5,11 @@ import sys
import time import time
import unicodedata import unicodedata
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from pathlib import Path from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
from typing import Any, Dict, Iterator, List from typing import Any
from xml.sax.saxutils import XMLGenerator from xml.sax.saxutils import XMLGenerator
from xml.sax.xmlreader import AttributesImpl from xml.sax.xmlreader import AttributesImpl
@ -18,17 +19,17 @@ from junit_xml import TestCase, TestSuite
class AbstractLogger(ABC): class AbstractLogger(ABC):
@abstractmethod @abstractmethod
def log(self, message: str, attributes: Dict[str, str] = {}) -> None: def log(self, message: str, attributes: dict[str, str] = {}) -> None:
pass pass
@abstractmethod @abstractmethod
@contextmanager @contextmanager
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
pass pass
@abstractmethod @abstractmethod
@contextmanager @contextmanager
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
pass pass
@abstractmethod @abstractmethod
@ -68,11 +69,11 @@ class JunitXMLLogger(AbstractLogger):
self._print_serial_logs = True self._print_serial_logs = True
atexit.register(self.close) atexit.register(self.close)
def log(self, message: str, attributes: Dict[str, str] = {}) -> None: def log(self, message: str, attributes: dict[str, str] = {}) -> None:
self.tests[self.currentSubtest].stdout += message + os.linesep self.tests[self.currentSubtest].stdout += message + os.linesep
@contextmanager @contextmanager
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
old_test = self.currentSubtest old_test = self.currentSubtest
self.tests.setdefault(name, self.TestCaseState()) self.tests.setdefault(name, self.TestCaseState())
self.currentSubtest = name self.currentSubtest = name
@ -82,7 +83,7 @@ class JunitXMLLogger(AbstractLogger):
self.currentSubtest = old_test self.currentSubtest = old_test
@contextmanager @contextmanager
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
self.log(message) self.log(message)
yield yield
@ -123,25 +124,25 @@ class JunitXMLLogger(AbstractLogger):
class CompositeLogger(AbstractLogger): class CompositeLogger(AbstractLogger):
def __init__(self, logger_list: List[AbstractLogger]) -> None: def __init__(self, logger_list: list[AbstractLogger]) -> None:
self.logger_list = logger_list self.logger_list = logger_list
def add_logger(self, logger: AbstractLogger) -> None: def add_logger(self, logger: AbstractLogger) -> None:
self.logger_list.append(logger) self.logger_list.append(logger)
def log(self, message: str, attributes: Dict[str, str] = {}) -> None: def log(self, message: str, attributes: dict[str, str] = {}) -> None:
for logger in self.logger_list: for logger in self.logger_list:
logger.log(message, attributes) logger.log(message, attributes)
@contextmanager @contextmanager
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
with ExitStack() as stack: with ExitStack() as stack:
for logger in self.logger_list: for logger in self.logger_list:
stack.enter_context(logger.subtest(name, attributes)) stack.enter_context(logger.subtest(name, attributes))
yield yield
@contextmanager @contextmanager
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
with ExitStack() as stack: with ExitStack() as stack:
for logger in self.logger_list: for logger in self.logger_list:
stack.enter_context(logger.nested(message, attributes)) stack.enter_context(logger.nested(message, attributes))
@ -173,7 +174,7 @@ class TerminalLogger(AbstractLogger):
def __init__(self) -> None: def __init__(self) -> None:
self._print_serial_logs = True self._print_serial_logs = True
def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str: def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
if "machine" in attributes: if "machine" in attributes:
return f"{attributes['machine']}: {message}" return f"{attributes['machine']}: {message}"
return message return message
@ -182,16 +183,16 @@ class TerminalLogger(AbstractLogger):
def _eprint(*args: object, **kwargs: Any) -> None: def _eprint(*args: object, **kwargs: Any) -> None:
print(*args, file=sys.stderr, **kwargs) print(*args, file=sys.stderr, **kwargs)
def log(self, message: str, attributes: Dict[str, str] = {}) -> None: def log(self, message: str, attributes: dict[str, str] = {}) -> None:
self._eprint(self.maybe_prefix(message, attributes)) self._eprint(self.maybe_prefix(message, attributes))
@contextmanager @contextmanager
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
with self.nested("subtest: " + name, attributes): with self.nested("subtest: " + name, attributes):
yield yield
@contextmanager @contextmanager
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
self._eprint( self._eprint(
self.maybe_prefix( self.maybe_prefix(
Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes
@ -241,12 +242,12 @@ class XMLLogger(AbstractLogger):
def sanitise(self, message: str) -> str: def sanitise(self, message: str) -> str:
return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C") return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")
def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str: def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
if "machine" in attributes: if "machine" in attributes:
return f"{attributes['machine']}: {message}" return f"{attributes['machine']}: {message}"
return message return message
def log_line(self, message: str, attributes: Dict[str, str]) -> None: def log_line(self, message: str, attributes: dict[str, str]) -> None:
self.xml.startElement("line", attrs=AttributesImpl(attributes)) self.xml.startElement("line", attrs=AttributesImpl(attributes))
self.xml.characters(message) self.xml.characters(message)
self.xml.endElement("line") self.xml.endElement("line")
@ -260,7 +261,7 @@ class XMLLogger(AbstractLogger):
def error(self, *args, **kwargs) -> None: # type: ignore def error(self, *args, **kwargs) -> None: # type: ignore
self.log(*args, **kwargs) self.log(*args, **kwargs)
def log(self, message: str, attributes: Dict[str, str] = {}) -> None: def log(self, message: str, attributes: dict[str, str] = {}) -> None:
self.drain_log_queue() self.drain_log_queue()
self.log_line(message, attributes) self.log_line(message, attributes)
@ -273,7 +274,7 @@ class XMLLogger(AbstractLogger):
self.enqueue({"msg": message, "machine": machine, "type": "serial"}) self.enqueue({"msg": message, "machine": machine, "type": "serial"})
def enqueue(self, item: Dict[str, str]) -> None: def enqueue(self, item: dict[str, str]) -> None:
self.queue.put(item) self.queue.put(item)
def drain_log_queue(self) -> None: def drain_log_queue(self) -> None:
@ -287,12 +288,12 @@ class XMLLogger(AbstractLogger):
pass pass
@contextmanager @contextmanager
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
with self.nested("subtest: " + name, attributes): with self.nested("subtest: " + name, attributes):
yield yield
@contextmanager @contextmanager
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
self.xml.startElement("nest", attrs=AttributesImpl({})) self.xml.startElement("nest", attrs=AttributesImpl({}))
self.xml.startElement("head", attrs=AttributesImpl(attributes)) self.xml.startElement("head", attrs=AttributesImpl(attributes))
self.xml.characters(message) self.xml.characters(message)

View file

@ -12,10 +12,11 @@ import sys
import tempfile import tempfile
import threading import threading
import time import time
from collections.abc import Iterable
from contextlib import _GeneratorContextManager, nullcontext from contextlib import _GeneratorContextManager, nullcontext
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from typing import Any, Callable, Optional
from test_driver.logger import AbstractLogger from test_driver.logger import AbstractLogger
@ -91,7 +92,7 @@ def make_command(args: list) -> str:
def _perform_ocr_on_screenshot( def _perform_ocr_on_screenshot(
screenshot_path: str, model_ids: Iterable[int] screenshot_path: str, model_ids: Iterable[int]
) -> List[str]: ) -> list[str]:
if shutil.which("tesseract") is None: if shutil.which("tesseract") is None:
raise Exception("OCR requested but enableOCR is false") raise Exception("OCR requested but enableOCR is false")
@ -260,7 +261,7 @@ class Machine:
# Store last serial console lines for use # Store last serial console lines for use
# of wait_for_console_text # of wait_for_console_text
last_lines: Queue = Queue() last_lines: Queue = Queue()
callbacks: List[Callable] callbacks: list[Callable]
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Machine '{self.name}'>" return f"<Machine '{self.name}'>"
@ -273,7 +274,7 @@ class Machine:
logger: AbstractLogger, logger: AbstractLogger,
name: str = "machine", name: str = "machine",
keep_vm_state: bool = False, keep_vm_state: bool = False,
callbacks: Optional[List[Callable]] = None, callbacks: Optional[list[Callable]] = None,
) -> None: ) -> None:
self.out_dir = out_dir self.out_dir = out_dir
self.tmp_dir = tmp_dir self.tmp_dir = tmp_dir
@ -314,7 +315,7 @@ class Machine:
def log_serial(self, msg: str) -> None: def log_serial(self, msg: str) -> None:
self.logger.log_serial(msg, self.name) self.logger.log_serial(msg, self.name)
def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager: def nested(self, msg: str, attrs: dict[str, str] = {}) -> _GeneratorContextManager:
my_attrs = {"machine": self.name} my_attrs = {"machine": self.name}
my_attrs.update(attrs) my_attrs.update(attrs)
return self.logger.nested(msg, my_attrs) return self.logger.nested(msg, my_attrs)
@ -373,7 +374,7 @@ class Machine:
): ):
retry(check_active, timeout) retry(check_active, timeout)
def get_unit_info(self, unit: str, user: Optional[str] = None) -> Dict[str, str]: def get_unit_info(self, unit: str, user: Optional[str] = None) -> dict[str, str]:
status, lines = self.systemctl(f'--no-pager show "{unit}"', user) status, lines = self.systemctl(f'--no-pager show "{unit}"', user)
if status != 0: if status != 0:
raise Exception( raise Exception(
@ -384,7 +385,7 @@ class Machine:
line_pattern = re.compile(r"^([^=]+)=(.*)$") line_pattern = re.compile(r"^([^=]+)=(.*)$")
def tuple_from_line(line: str) -> Tuple[str, str]: def tuple_from_line(line: str) -> tuple[str, str]:
match = line_pattern.match(line) match = line_pattern.match(line)
assert match is not None assert match is not None
return match[1], match[2] return match[1], match[2]
@ -424,7 +425,7 @@ class Machine:
assert match[1] == property, invalid_output_message assert match[1] == property, invalid_output_message
return match[2] return match[2]
def systemctl(self, q: str, user: Optional[str] = None) -> Tuple[int, str]: def systemctl(self, q: str, user: Optional[str] = None) -> tuple[int, str]:
""" """
Runs `systemctl` commands with optional support for Runs `systemctl` commands with optional support for
`systemctl --user` `systemctl --user`
@ -481,7 +482,7 @@ class Machine:
check_return: bool = True, check_return: bool = True,
check_output: bool = True, check_output: bool = True,
timeout: Optional[int] = 900, timeout: Optional[int] = 900,
) -> Tuple[int, str]: ) -> tuple[int, str]:
""" """
Execute a shell command, returning a list `(status, stdout)`. Execute a shell command, returning a list `(status, stdout)`.
@ -798,10 +799,10 @@ class Machine:
with self.nested(f"waiting for TCP port {port} on {addr} to be closed"): with self.nested(f"waiting for TCP port {port} on {addr} to be closed"):
retry(port_is_closed, timeout) retry(port_is_closed, timeout)
def start_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]: def start_job(self, jobname: str, user: Optional[str] = None) -> tuple[int, str]:
return self.systemctl(f"start {jobname}", user) return self.systemctl(f"start {jobname}", user)
def stop_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]: def stop_job(self, jobname: str, user: Optional[str] = None) -> tuple[int, str]:
return self.systemctl(f"stop {jobname}", user) return self.systemctl(f"stop {jobname}", user)
def wait_for_job(self, jobname: str) -> None: def wait_for_job(self, jobname: str) -> None:
@ -942,13 +943,13 @@ class Machine:
"""Debugging: Dump the contents of the TTY<n>""" """Debugging: Dump the contents of the TTY<n>"""
self.execute(f"fold -w 80 /dev/vcs{tty} | systemd-cat") self.execute(f"fold -w 80 /dev/vcs{tty} | systemd-cat")
def _get_screen_text_variants(self, model_ids: Iterable[int]) -> List[str]: def _get_screen_text_variants(self, model_ids: Iterable[int]) -> list[str]:
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
screenshot_path = os.path.join(tmpdir, "ppm") screenshot_path = os.path.join(tmpdir, "ppm")
self.send_monitor_command(f"screendump {screenshot_path}") self.send_monitor_command(f"screendump {screenshot_path}")
return _perform_ocr_on_screenshot(screenshot_path, model_ids) return _perform_ocr_on_screenshot(screenshot_path, model_ids)
def get_screen_text_variants(self) -> List[str]: def get_screen_text_variants(self) -> list[str]:
""" """
Return a list of different interpretations of what is currently Return a list of different interpretations of what is currently
visible on the machine's screen using optical character visible on the machine's screen using optical character
@ -1168,7 +1169,7 @@ class Machine:
with self.nested("waiting for the X11 server"): with self.nested("waiting for the X11 server"):
retry(check_x, timeout) retry(check_x, timeout)
def get_window_names(self) -> List[str]: def get_window_names(self) -> list[str]:
return self.succeed( return self.succeed(
r"xwininfo -root -tree | sed 's/.*0x[0-9a-f]* \"\([^\"]*\)\".*/\1/; t; d'" r"xwininfo -root -tree | sed 's/.*0x[0-9a-f]* \"\([^\"]*\)\".*/\1/; t; d'"
).splitlines() ).splitlines()