diff --git a/nixos/doc/manual/development/writing-nixos-tests.section.md b/nixos/doc/manual/development/writing-nixos-tests.section.md index bd588e2ba80b..5b08975e5ea4 100644 --- a/nixos/doc/manual/development/writing-nixos-tests.section.md +++ b/nixos/doc/manual/development/writing-nixos-tests.section.md @@ -121,8 +121,7 @@ and checks that the output is more-or-less correct: ```py machine.start() machine.wait_for_unit("default.target") -if not "Linux" in machine.succeed("uname"): - raise Exception("Wrong OS") +t.assertIn("Linux", machine.succeed("uname"), "Wrong OS") ``` The first line is technically unnecessary; machines are implicitly started @@ -134,6 +133,8 @@ starting them in parallel: start_all() ``` +Under the variable `t`, all assertions from [`unittest.TestCase`](https://docs.python.org/3/library/unittest.html) are available. + If the hostname of a node contains characters that can't be used in a Python variable name, those characters will be replaced with underscores in the variable name, so `nodes.machine-a` will be exposed diff --git a/nixos/lib/test-driver/src/test_driver/driver.py b/nixos/lib/test-driver/src/test_driver/driver.py index 6061c1bc09b8..8a878e9a7558 100644 --- a/nixos/lib/test-driver/src/test_driver/driver.py +++ b/nixos/lib/test-driver/src/test_driver/driver.py @@ -1,12 +1,15 @@ import os import re import signal +import sys import tempfile import threading +import traceback from collections.abc import Callable, Iterator from contextlib import AbstractContextManager, contextmanager from pathlib import Path from typing import Any +from unittest import TestCase from test_driver.logger import AbstractLogger from test_driver.machine import Machine, NixStartScript, retry @@ -38,6 +41,14 @@ def pythonize_name(name: str) -> str: return re.sub(r"^[^A-z_]|[^A-z0-9_]", "_", name) +class NixOSAssertionError(AssertionError): + pass + + +class Tester(TestCase): + failureException = NixOSAssertionError + + class Driver: """A handle to the driver that sets up the environment and runs the tests""" @@ -140,6 +151,7 @@ class Driver: serial_stdout_on=self.serial_stdout_on, polling_condition=self.polling_condition, Machine=Machine, # for typing + t=Tester(), ) machine_symbols = {pythonize_name(m.name): m for m in self.machines} # If there's exactly one machine, make it available under the name @@ -163,7 +175,31 @@ class Driver: """Run the test script""" with self.logger.nested("run the VM test script"): symbols = self.test_symbols() # call eagerly - exec(self.tests, symbols, None) + try: + exec(self.tests, symbols, None) + except NixOSAssertionError: + exc_type, exc, tb = sys.exc_info() + filtered = [ + frame + for frame in traceback.extract_tb(tb) + if frame.filename == "" + ] + + self.logger.log_test_error("Traceback (most recent call last):") + code = self.tests.splitlines() + for frame, line in zip(filtered, traceback.format_list(filtered)): + self.logger.log_test_error(line.rstrip()) + if lineno := frame.lineno: + self.logger.log_test_error( + f" {code[lineno - 1].strip()}", + ) + + self.logger.log_test_error("") # blank line for readability + exc_prefix = exc_type.__name__ if exc_type is not None else "Error" + for line in f"{exc_prefix}: {exc}".splitlines(): + self.logger.log_test_error(line) + + sys.exit(1) def run_tests(self) -> None: """Run the test script (for non-interactive test runs)""" diff --git a/nixos/lib/test-driver/src/test_driver/logger.py b/nixos/lib/test-driver/src/test_driver/logger.py index 564d39f4f055..fa195080fa2b 100644 --- a/nixos/lib/test-driver/src/test_driver/logger.py +++ b/nixos/lib/test-driver/src/test_driver/logger.py @@ -44,6 +44,10 @@ class AbstractLogger(ABC): def error(self, *args, **kwargs) -> None: # type: ignore pass + @abstractmethod + def log_test_error(self, *args, **kwargs) -> None: # type:ignore + pass + @abstractmethod def log_serial(self, message: str, machine: str) -> None: pass @@ -97,6 +101,9 @@ class JunitXMLLogger(AbstractLogger): self.tests[self.currentSubtest].stderr += args[0] + os.linesep self.tests[self.currentSubtest].failure = True + def log_test_error(self, *args, **kwargs) -> None: # type: ignore + self.error(*args, **kwargs) + def log_serial(self, message: str, machine: str) -> None: if not self._print_serial_logs: return @@ -156,6 +163,10 @@ class CompositeLogger(AbstractLogger): for logger in self.logger_list: logger.warning(*args, **kwargs) + def log_test_error(self, *args, **kwargs) -> None: # type: ignore + for logger in self.logger_list: + logger.log_test_error(*args, **kwargs) + def error(self, *args, **kwargs) -> None: # type: ignore for logger in self.logger_list: logger.error(*args, **kwargs) @@ -222,6 +233,11 @@ class TerminalLogger(AbstractLogger): self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL) + def log_test_error(self, *args, **kwargs) -> None: # type: ignore + prefix = Fore.RED + "!!! " + Style.RESET_ALL + # NOTE: using `warning` instead of `error` to ensure it does not exit after printing the first log + self.warning(f"{prefix}{args[0]}", *args[1:], **kwargs) + class XMLLogger(AbstractLogger): def __init__(self, outfile: str) -> None: @@ -261,6 +277,9 @@ class XMLLogger(AbstractLogger): def error(self, *args, **kwargs) -> None: # type: ignore self.log(*args, **kwargs) + def log_test_error(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + def log(self, message: str, attributes: dict[str, str] = {}) -> None: self.drain_log_queue() self.log_line(message, attributes) diff --git a/nixos/lib/test-script-prepend.py b/nixos/lib/test-script-prepend.py index 9d2efdf97303..31dad14ef8dd 100644 --- a/nixos/lib/test-script-prepend.py +++ b/nixos/lib/test-script-prepend.py @@ -8,6 +8,7 @@ from test_driver.logger import AbstractLogger from typing import Callable, Iterator, ContextManager, Optional, List, Dict, Any, Union from typing_extensions import Protocol from pathlib import Path +from unittest import TestCase class RetryProtocol(Protocol): @@ -51,3 +52,4 @@ join_all: Callable[[], None] serial_stdout_off: Callable[[], None] serial_stdout_on: Callable[[], None] polling_condition: PollingConditionProtocol +t: TestCase