mirror of
https://github.com/NixOS/nixpkgs.git
synced 2025-06-21 17:01:10 +03:00
nixos/test-driver: apply ruff check suggestions
This commit is contained in:
parent
9069a281a7
commit
b25360a7e5
3 changed files with 48 additions and 45 deletions
|
@ -3,9 +3,10 @@ import re
|
|||
import signal
|
||||
import tempfile
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from collections.abc import Iterator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from colorama import Fore, Style
|
||||
|
||||
|
@ -44,17 +45,17 @@ class Driver:
|
|||
and runs the tests"""
|
||||
|
||||
tests: str
|
||||
vlans: List[VLan]
|
||||
machines: List[Machine]
|
||||
polling_conditions: List[PollingCondition]
|
||||
vlans: list[VLan]
|
||||
machines: list[Machine]
|
||||
polling_conditions: list[PollingCondition]
|
||||
global_timeout: int
|
||||
race_timer: threading.Timer
|
||||
logger: AbstractLogger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_scripts: List[str],
|
||||
vlans: List[int],
|
||||
start_scripts: list[str],
|
||||
vlans: list[int],
|
||||
tests: str,
|
||||
out_dir: Path,
|
||||
logger: AbstractLogger,
|
||||
|
@ -73,7 +74,7 @@ class Driver:
|
|||
vlans = list(set(vlans))
|
||||
self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in vlans]
|
||||
|
||||
def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
|
||||
def cmd(scripts: list[str]) -> Iterator[NixStartScript]:
|
||||
for s in scripts:
|
||||
yield NixStartScript(s)
|
||||
|
||||
|
@ -119,7 +120,7 @@ class Driver:
|
|||
self.logger.error(f'Test "{name}" failed with error: "{e}"')
|
||||
raise e
|
||||
|
||||
def test_symbols(self) -> Dict[str, Any]:
|
||||
def test_symbols(self) -> dict[str, Any]:
|
||||
@contextmanager
|
||||
def subtest(name: str) -> Iterator[None]:
|
||||
return self.subtest(name)
|
||||
|
@ -277,7 +278,7 @@ class Driver:
|
|||
*,
|
||||
seconds_interval: float = 2.0,
|
||||
description: Optional[str] = None,
|
||||
) -> Union[Callable[[Callable], ContextManager], ContextManager]:
|
||||
) -> Union[Callable[[Callable], AbstractContextManager], AbstractContextManager]:
|
||||
driver = self
|
||||
|
||||
class Poll:
|
||||
|
|
|
@ -5,10 +5,11 @@ import sys
|
|||
import time
|
||||
import unicodedata
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterator
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from typing import Any, Dict, Iterator, List
|
||||
from typing import Any
|
||||
from xml.sax.saxutils import XMLGenerator
|
||||
from xml.sax.xmlreader import AttributesImpl
|
||||
|
||||
|
@ -18,17 +19,17 @@ from junit_xml import TestCase, TestSuite
|
|||
|
||||
class AbstractLogger(ABC):
|
||||
@abstractmethod
|
||||
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
|
||||
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@contextmanager
|
||||
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
|
||||
def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@contextmanager
|
||||
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
|
||||
def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
@ -68,11 +69,11 @@ class JunitXMLLogger(AbstractLogger):
|
|||
self._print_serial_logs = True
|
||||
atexit.register(self.close)
|
||||
|
||||
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
|
||||
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
|
||||
self.tests[self.currentSubtest].stdout += message + os.linesep
|
||||
|
||||
@contextmanager
|
||||
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
|
||||
def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
|
||||
old_test = self.currentSubtest
|
||||
self.tests.setdefault(name, self.TestCaseState())
|
||||
self.currentSubtest = name
|
||||
|
@ -82,7 +83,7 @@ class JunitXMLLogger(AbstractLogger):
|
|||
self.currentSubtest = old_test
|
||||
|
||||
@contextmanager
|
||||
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
|
||||
def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
|
||||
self.log(message)
|
||||
yield
|
||||
|
||||
|
@ -123,25 +124,25 @@ class JunitXMLLogger(AbstractLogger):
|
|||
|
||||
|
||||
class CompositeLogger(AbstractLogger):
|
||||
def __init__(self, logger_list: List[AbstractLogger]) -> None:
|
||||
def __init__(self, logger_list: list[AbstractLogger]) -> None:
|
||||
self.logger_list = logger_list
|
||||
|
||||
def add_logger(self, logger: AbstractLogger) -> None:
|
||||
self.logger_list.append(logger)
|
||||
|
||||
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
|
||||
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
|
||||
for logger in self.logger_list:
|
||||
logger.log(message, attributes)
|
||||
|
||||
@contextmanager
|
||||
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
|
||||
def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
|
||||
with ExitStack() as stack:
|
||||
for logger in self.logger_list:
|
||||
stack.enter_context(logger.subtest(name, attributes))
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
|
||||
def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
|
||||
with ExitStack() as stack:
|
||||
for logger in self.logger_list:
|
||||
stack.enter_context(logger.nested(message, attributes))
|
||||
|
@ -173,7 +174,7 @@ class TerminalLogger(AbstractLogger):
|
|||
def __init__(self) -> None:
|
||||
self._print_serial_logs = True
|
||||
|
||||
def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
|
||||
def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
|
||||
if "machine" in attributes:
|
||||
return f"{attributes['machine']}: {message}"
|
||||
return message
|
||||
|
@ -182,16 +183,16 @@ class TerminalLogger(AbstractLogger):
|
|||
def _eprint(*args: object, **kwargs: Any) -> None:
|
||||
print(*args, file=sys.stderr, **kwargs)
|
||||
|
||||
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
|
||||
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
|
||||
self._eprint(self.maybe_prefix(message, attributes))
|
||||
|
||||
@contextmanager
|
||||
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
|
||||
def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
|
||||
with self.nested("subtest: " + name, attributes):
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
|
||||
def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
|
||||
self._eprint(
|
||||
self.maybe_prefix(
|
||||
Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes
|
||||
|
@ -241,12 +242,12 @@ class XMLLogger(AbstractLogger):
|
|||
def sanitise(self, message: str) -> str:
|
||||
return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")
|
||||
|
||||
def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
|
||||
def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
|
||||
if "machine" in attributes:
|
||||
return f"{attributes['machine']}: {message}"
|
||||
return message
|
||||
|
||||
def log_line(self, message: str, attributes: Dict[str, str]) -> None:
|
||||
def log_line(self, message: str, attributes: dict[str, str]) -> None:
|
||||
self.xml.startElement("line", attrs=AttributesImpl(attributes))
|
||||
self.xml.characters(message)
|
||||
self.xml.endElement("line")
|
||||
|
@ -260,7 +261,7 @@ class XMLLogger(AbstractLogger):
|
|||
def error(self, *args, **kwargs) -> None: # type: ignore
|
||||
self.log(*args, **kwargs)
|
||||
|
||||
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
|
||||
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
|
||||
self.drain_log_queue()
|
||||
self.log_line(message, attributes)
|
||||
|
||||
|
@ -273,7 +274,7 @@ class XMLLogger(AbstractLogger):
|
|||
|
||||
self.enqueue({"msg": message, "machine": machine, "type": "serial"})
|
||||
|
||||
def enqueue(self, item: Dict[str, str]) -> None:
|
||||
def enqueue(self, item: dict[str, str]) -> None:
|
||||
self.queue.put(item)
|
||||
|
||||
def drain_log_queue(self) -> None:
|
||||
|
@ -287,12 +288,12 @@ class XMLLogger(AbstractLogger):
|
|||
pass
|
||||
|
||||
@contextmanager
|
||||
def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
|
||||
def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
|
||||
with self.nested("subtest: " + name, attributes):
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
|
||||
def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
|
||||
self.xml.startElement("nest", attrs=AttributesImpl({}))
|
||||
self.xml.startElement("head", attrs=AttributesImpl(attributes))
|
||||
self.xml.characters(message)
|
||||
|
|
|
@ -12,10 +12,11 @@ import sys
|
|||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from contextlib import _GeneratorContextManager, nullcontext
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from test_driver.logger import AbstractLogger
|
||||
|
||||
|
@ -91,7 +92,7 @@ def make_command(args: list) -> str:
|
|||
|
||||
def _perform_ocr_on_screenshot(
|
||||
screenshot_path: str, model_ids: Iterable[int]
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
if shutil.which("tesseract") is None:
|
||||
raise Exception("OCR requested but enableOCR is false")
|
||||
|
||||
|
@ -260,7 +261,7 @@ class Machine:
|
|||
# Store last serial console lines for use
|
||||
# of wait_for_console_text
|
||||
last_lines: Queue = Queue()
|
||||
callbacks: List[Callable]
|
||||
callbacks: list[Callable]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Machine '{self.name}'>"
|
||||
|
@ -273,7 +274,7 @@ class Machine:
|
|||
logger: AbstractLogger,
|
||||
name: str = "machine",
|
||||
keep_vm_state: bool = False,
|
||||
callbacks: Optional[List[Callable]] = None,
|
||||
callbacks: Optional[list[Callable]] = None,
|
||||
) -> None:
|
||||
self.out_dir = out_dir
|
||||
self.tmp_dir = tmp_dir
|
||||
|
@ -314,7 +315,7 @@ class Machine:
|
|||
def log_serial(self, msg: str) -> None:
|
||||
self.logger.log_serial(msg, self.name)
|
||||
|
||||
def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager:
|
||||
def nested(self, msg: str, attrs: dict[str, str] = {}) -> _GeneratorContextManager:
|
||||
my_attrs = {"machine": self.name}
|
||||
my_attrs.update(attrs)
|
||||
return self.logger.nested(msg, my_attrs)
|
||||
|
@ -373,7 +374,7 @@ class Machine:
|
|||
):
|
||||
retry(check_active, timeout)
|
||||
|
||||
def get_unit_info(self, unit: str, user: Optional[str] = None) -> Dict[str, str]:
|
||||
def get_unit_info(self, unit: str, user: Optional[str] = None) -> dict[str, str]:
|
||||
status, lines = self.systemctl(f'--no-pager show "{unit}"', user)
|
||||
if status != 0:
|
||||
raise Exception(
|
||||
|
@ -384,7 +385,7 @@ class Machine:
|
|||
|
||||
line_pattern = re.compile(r"^([^=]+)=(.*)$")
|
||||
|
||||
def tuple_from_line(line: str) -> Tuple[str, str]:
|
||||
def tuple_from_line(line: str) -> tuple[str, str]:
|
||||
match = line_pattern.match(line)
|
||||
assert match is not None
|
||||
return match[1], match[2]
|
||||
|
@ -424,7 +425,7 @@ class Machine:
|
|||
assert match[1] == property, invalid_output_message
|
||||
return match[2]
|
||||
|
||||
def systemctl(self, q: str, user: Optional[str] = None) -> Tuple[int, str]:
|
||||
def systemctl(self, q: str, user: Optional[str] = None) -> tuple[int, str]:
|
||||
"""
|
||||
Runs `systemctl` commands with optional support for
|
||||
`systemctl --user`
|
||||
|
@ -481,7 +482,7 @@ class Machine:
|
|||
check_return: bool = True,
|
||||
check_output: bool = True,
|
||||
timeout: Optional[int] = 900,
|
||||
) -> Tuple[int, str]:
|
||||
) -> tuple[int, str]:
|
||||
"""
|
||||
Execute a shell command, returning a list `(status, stdout)`.
|
||||
|
||||
|
@ -798,10 +799,10 @@ class Machine:
|
|||
with self.nested(f"waiting for TCP port {port} on {addr} to be closed"):
|
||||
retry(port_is_closed, timeout)
|
||||
|
||||
def start_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]:
|
||||
def start_job(self, jobname: str, user: Optional[str] = None) -> tuple[int, str]:
|
||||
return self.systemctl(f"start {jobname}", user)
|
||||
|
||||
def stop_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]:
|
||||
def stop_job(self, jobname: str, user: Optional[str] = None) -> tuple[int, str]:
|
||||
return self.systemctl(f"stop {jobname}", user)
|
||||
|
||||
def wait_for_job(self, jobname: str) -> None:
|
||||
|
@ -942,13 +943,13 @@ class Machine:
|
|||
"""Debugging: Dump the contents of the TTY<n>"""
|
||||
self.execute(f"fold -w 80 /dev/vcs{tty} | systemd-cat")
|
||||
|
||||
def _get_screen_text_variants(self, model_ids: Iterable[int]) -> List[str]:
|
||||
def _get_screen_text_variants(self, model_ids: Iterable[int]) -> list[str]:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
screenshot_path = os.path.join(tmpdir, "ppm")
|
||||
self.send_monitor_command(f"screendump {screenshot_path}")
|
||||
return _perform_ocr_on_screenshot(screenshot_path, model_ids)
|
||||
|
||||
def get_screen_text_variants(self) -> List[str]:
|
||||
def get_screen_text_variants(self) -> list[str]:
|
||||
"""
|
||||
Return a list of different interpretations of what is currently
|
||||
visible on the machine's screen using optical character
|
||||
|
@ -1168,7 +1169,7 @@ class Machine:
|
|||
with self.nested("waiting for the X11 server"):
|
||||
retry(check_x, timeout)
|
||||
|
||||
def get_window_names(self) -> List[str]:
|
||||
def get_window_names(self) -> list[str]:
|
||||
return self.succeed(
|
||||
r"xwininfo -root -tree | sed 's/.*0x[0-9a-f]* \"\([^\"]*\)\".*/\1/; t; d'"
|
||||
).splitlines()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue