diff --git a/nixos/lib/test-driver/src/test_driver/machine/__init__.py b/nixos/lib/test-driver/src/test_driver/machine/__init__.py index 1b9dd1262ce6..7ad30178394e 100644 --- a/nixos/lib/test-driver/src/test_driver/machine/__init__.py +++ b/nixos/lib/test-driver/src/test_driver/machine/__init__.py @@ -13,8 +13,8 @@ import sys import tempfile import threading import time -from collections.abc import Callable, Iterable -from contextlib import _GeneratorContextManager, nullcontext +from collections.abc import Callable, Generator +from contextlib import _GeneratorContextManager, contextmanager, nullcontext from pathlib import Path from queue import Queue from typing import Any @@ -22,6 +22,7 @@ from typing import Any from test_driver.errors import MachineError, RequestedAssertionFailed from test_driver.logger import AbstractLogger +from .ocr import perform_ocr_on_screenshot, perform_ocr_variants_on_screenshot from .qmp import QMPSession CHAR_TO_KEY = { @@ -92,84 +93,6 @@ def make_command(args: list) -> str: return " ".join(map(shlex.quote, (map(str, args)))) -def _preprocess_screenshot(screenshot_path: str, negate: bool = False) -> str: - magick_args = [ - "-filter", - "Catrom", - "-density", - "72", - "-resample", - "300", - "-contrast", - "-normalize", - "-despeckle", - "-type", - "grayscale", - "-sharpen", - "1", - "-posterize", - "3", - ] - out_file = screenshot_path - - if negate: - magick_args.append("-negate") - out_file += ".negative" - - magick_args += [ - "-gamma", - "100", - "-blur", - "1x65535", - ] - out_file += ".png" - - ret = subprocess.run( - ["magick", "convert"] + magick_args + [screenshot_path, out_file], - capture_output=True, - ) - - if ret.returncode != 0: - raise MachineError( - f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}" - ) - - return out_file - - -def _perform_ocr_on_screenshot( - screenshot_path: str, model_ids: Iterable[int] -) -> list[str]: - if shutil.which("tesseract") is None: - raise MachineError("OCR requested but enableOCR is false") - - processed_image = _preprocess_screenshot(screenshot_path, negate=False) - processed_negative = _preprocess_screenshot(screenshot_path, negate=True) - - model_results = [] - for image in [screenshot_path, processed_image, processed_negative]: - for model_id in model_ids: - ret = subprocess.run( - [ - "tesseract", - image, - "-", - "--oem", - str(model_id), - "-c", - "debug_file=/dev/null", - "--psm", - "11", - ], - capture_output=True, - ) - if ret.returncode != 0: - raise MachineError(f"OCR failed with exit code {ret.returncode}") - model_results.append(ret.stdout.decode("utf-8")) - - return model_results - - def retry(fn: Callable, timeout: int = 900) -> None: """Call the given function repeatedly, with 1 second intervals, until it returns True or a timeout is reached. @@ -910,6 +833,17 @@ class Machine: self.log(f"(connecting took {toc - tic:.2f} seconds)") self.connected = True + @contextmanager + def _managed_screenshot(self) -> Generator[str]: + """ + Take a screenshot and yield the screenshot filepath. + The file will be deleted when leaving the generator. + """ + with tempfile.TemporaryDirectory() as tmpdir: + screenshot_path: str = os.path.join(tmpdir, "ppm") + self.send_monitor_command(f"screendump {screenshot_path}") + yield screenshot_path + def screenshot(self, filename: str) -> None: """ Take a picture of the display of the virtual machine, in PNG format. @@ -919,17 +853,19 @@ class Machine: filename += ".png" if "/" not in filename: filename = os.path.join(self.out_dir, filename) - tmp = f"{filename}.ppm" with self.nested( f"making screenshot {filename}", {"image": os.path.basename(filename)}, ): - self.send_monitor_command(f"screendump {tmp}") - ret = subprocess.run(f"pnmtopng '{tmp}' > '{filename}'", shell=True) - os.unlink(tmp) - if ret.returncode != 0: - raise MachineError("Cannot convert screenshot") + with self._managed_screenshot() as screenshot_path: + ret = subprocess.run( + f"pnmtopng '{screenshot_path}' > '{filename}'", shell=True + ) + if ret.returncode != 0: + raise MachineError( + f"Cannot convert screenshot (pnmtopng returned code {ret.returncode})" + ) def copy_from_host_via_shell(self, source: str, target: str) -> None: """Copy a file from the host into the guest by piping it over the @@ -1003,12 +939,6 @@ class Machine: """Debugging: Dump the contents of the TTY""" self.execute(f"fold -w 80 /dev/vcs{tty} | systemd-cat") - def _get_screen_text_variants(self, model_ids: Iterable[int]) -> list[str]: - with tempfile.TemporaryDirectory() as tmpdir: - screenshot_path = os.path.join(tmpdir, "ppm") - self.send_monitor_command(f"screendump {screenshot_path}") - return _perform_ocr_on_screenshot(screenshot_path, model_ids) - def get_screen_text_variants(self) -> list[str]: """ Return a list of different interpretations of what is currently @@ -1021,7 +951,8 @@ class Machine: This requires [`enableOCR`](#test-opt-enableOCR) to be set to `true`. ::: """ - return self._get_screen_text_variants([0, 1, 2]) + with self._managed_screenshot() as screenshot_path: + return perform_ocr_variants_on_screenshot(screenshot_path) def get_screen_text(self) -> str: """ @@ -1032,7 +963,8 @@ class Machine: This requires [`enableOCR`](#test-opt-enableOCR) to be set to `true`. ::: """ - return self._get_screen_text_variants([2])[0] + with self._managed_screenshot() as screenshot_path: + return perform_ocr_on_screenshot(screenshot_path) def wait_for_text(self, regex: str, timeout: int = 900) -> None: """ diff --git a/nixos/lib/test-driver/src/test_driver/machine/ocr.py b/nixos/lib/test-driver/src/test_driver/machine/ocr.py new file mode 100644 index 000000000000..7b1cc62dd7db --- /dev/null +++ b/nixos/lib/test-driver/src/test_driver/machine/ocr.py @@ -0,0 +1,111 @@ +import shutil +import subprocess + +from test_driver.errors import MachineError + + +def perform_ocr_on_screenshot(screenshot_path: str) -> str: + """ + Perform OCR on a screenshot that contains text. + Returns a string with all words that could be found. + """ + return perform_ocr_variants_on_screenshot(screenshot_path, False)[0] + + +def perform_ocr_variants_on_screenshot( + screenshot_path: str, variants: bool = True +) -> list[str]: + """ + Same as perform_ocr_on_screenshot but will create variants of the images + that can lead to more words being detected. + Returns a string with words for each variant. + """ + if shutil.which("tesseract") is None: + raise MachineError("OCR requested but `tesseract` is not available") + + # tesseract --help-oem + # OCR Engine modes (OEM): + # 0|tesseract_only Legacy engine only. + # 1|lstm_only Neural nets LSTM engine only. + # 2|tesseract_lstm_combined Legacy + LSTM engines. + # 3|default Default, based on what is available. + model_ids: list[int] = [0, 1, 2] if variants else [3] + + image_paths = [ + screenshot_path, + _preprocess_screenshot(screenshot_path, negate=False), + _preprocess_screenshot(screenshot_path, negate=True), + ] + + def run_tesseract(image: str, model_id: int) -> str: + ret = subprocess.run( + [ + "tesseract", + image, + "-", + "--oem", + str(model_id), + "-c", + "debug_file=/dev/null", + "--psm", + "11", + ], + capture_output=True, + ) + if ret.returncode != 0: + raise MachineError(f"OCR failed with exit code {ret.returncode}") + return ret.stdout.decode("utf-8") + + return [ + run_tesseract(image, model_id) + for image in image_paths + for model_id in model_ids + ] + + +def _preprocess_screenshot(screenshot_path: str, negate: bool = False) -> str: + if shutil.which("magick") is None: + raise MachineError("OCR requested but `magick` is not available") + + magick_args = [ + "-filter", + "Catrom", + "-density", + "72", + "-resample", + "300", + "-contrast", + "-normalize", + "-despeckle", + "-type", + "grayscale", + "-sharpen", + "1", + "-posterize", + "3", + ] + out_file = screenshot_path + + if negate: + magick_args.append("-negate") + out_file += ".negative" + + magick_args += [ + "-gamma", + "100", + "-blur", + "1x65535", + ] + out_file += ".png" + + ret = subprocess.run( + ["magick", "convert"] + magick_args + [screenshot_path, out_file], + capture_output=True, + ) + + if ret.returncode != 0: + raise MachineError( + f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}" + ) + + return out_file