0
0
Fork 0
mirror of https://github.com/NixOS/nixpkgs.git synced 2025-07-14 22:20:30 +03:00

test-driver: Factor out OCR related code to machine/ocr.py

This commit is contained in:
Jacek Galowicz 2025-06-30 11:54:21 +02:00
parent 2c8500b91d
commit 9f10c9bce8
2 changed files with 137 additions and 94 deletions

View file

@ -13,8 +13,8 @@ import sys
import tempfile
import threading
import time
from collections.abc import Callable, Iterable
from contextlib import _GeneratorContextManager, nullcontext
from collections.abc import Callable, Generator
from contextlib import _GeneratorContextManager, contextmanager, nullcontext
from pathlib import Path
from queue import Queue
from typing import Any
@ -22,6 +22,7 @@ from typing import Any
from test_driver.errors import MachineError, RequestedAssertionFailed
from test_driver.logger import AbstractLogger
from .ocr import perform_ocr_on_screenshot, perform_ocr_variants_on_screenshot
from .qmp import QMPSession
CHAR_TO_KEY = {
@ -92,84 +93,6 @@ def make_command(args: list) -> str:
return " ".join(map(shlex.quote, (map(str, args))))
def _preprocess_screenshot(screenshot_path: str, negate: bool = False) -> str:
    """
    Convert a screenshot with ImageMagick and return the path of the result.

    The converted image is written to `<screenshot_path>.png`. With `negate`
    set, the `-negate`, `-gamma` and `-blur` options are appended as well and
    the result is written to `<screenshot_path>.negative.png` instead.

    Raises `MachineError` if `magick` exits with a non-zero code.
    """
    # Options applied to every image, in the order they are handed to magick.
    common_args = [
        "-filter",
        "Catrom",
        "-density",
        "72",
        "-resample",
        "300",
        "-contrast",
        "-normalize",
        "-despeckle",
        "-type",
        "grayscale",
        "-sharpen",
        "1",
        "-posterize",
        "3",
    ]
    # Options that are only appended for the negated variant.
    negate_args = ["-negate", "-gamma", "100", "-blur", "1x65535"] if negate else []
    out_file = screenshot_path + (".negative.png" if negate else ".png")

    ret = subprocess.run(
        ["magick", "convert", *common_args, *negate_args, screenshot_path, out_file],
        capture_output=True,
    )
    if ret.returncode == 0:
        return out_file
    raise MachineError(
        f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}"
    )
def _perform_ocr_on_screenshot(
    screenshot_path: str, model_ids: Iterable[int]
) -> list[str]:
    """
    Run tesseract on a screenshot and on two preprocessed copies of it.

    Every image (the screenshot itself, its preprocessed copy and its
    preprocessed negative) is passed to tesseract once per entry of
    `model_ids`, which is used as the `--oem` value.

    Returns the decoded tesseract output of each run, grouped by image with
    the unmodified screenshot first.

    Raises `MachineError` if `tesseract` cannot be found or exits with a
    non-zero code.
    """
    if shutil.which("tesseract") is None:
        raise MachineError("OCR requested but enableOCR is false")

    def recognise(image: str, model_id: int) -> str:
        # A single tesseract invocation that prints its result to stdout ("-").
        ret = subprocess.run(
            [
                "tesseract",
                image,
                "-",
                "--oem",
                str(model_id),
                "-c",
                "debug_file=/dev/null",
                "--psm",
                "11",
            ],
            capture_output=True,
        )
        if ret.returncode != 0:
            raise MachineError(f"OCR failed with exit code {ret.returncode}")
        return ret.stdout.decode("utf-8")

    # Both preprocessed images are created before the first tesseract run.
    images = [
        screenshot_path,
        _preprocess_screenshot(screenshot_path, negate=False),
        _preprocess_screenshot(screenshot_path, negate=True),
    ]
    return [recognise(image, model_id) for image in images for model_id in model_ids]
def retry(fn: Callable, timeout: int = 900) -> None:
"""Call the given function repeatedly, with 1 second intervals,
until it returns True or a timeout is reached.
@ -910,6 +833,17 @@ class Machine:
self.log(f"(connecting took {toc - tic:.2f} seconds)")
self.connected = True
@contextmanager
def _managed_screenshot(self) -> Generator[str]:
    """
    Take a screenshot and yield the path of the resulting file.

    The screenshot is stored inside a temporary directory which is removed,
    together with the file, once the surrounding `with` block is left.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        screenshot_path: str = os.path.join(scratch_dir, "ppm")
        dump_command = f"screendump {screenshot_path}"
        self.send_monitor_command(dump_command)
        yield screenshot_path
def screenshot(self, filename: str) -> None:
"""
Take a picture of the display of the virtual machine, in PNG format.
@ -919,17 +853,19 @@ class Machine:
filename += ".png"
if "/" not in filename:
filename = os.path.join(self.out_dir, filename)
tmp = f"{filename}.ppm"
with self.nested(
f"making screenshot {filename}",
{"image": os.path.basename(filename)},
):
self.send_monitor_command(f"screendump {tmp}")
ret = subprocess.run(f"pnmtopng '{tmp}' > '{filename}'", shell=True)
os.unlink(tmp)
if ret.returncode != 0:
raise MachineError("Cannot convert screenshot")
with self._managed_screenshot() as screenshot_path:
ret = subprocess.run(
f"pnmtopng '{screenshot_path}' > '{filename}'", shell=True
)
if ret.returncode != 0:
raise MachineError(
f"Cannot convert screenshot (pnmtopng returned code {ret.returncode})"
)
def copy_from_host_via_shell(self, source: str, target: str) -> None:
"""Copy a file from the host into the guest by piping it over the
@ -1003,12 +939,6 @@ class Machine:
"""Debugging: Dump the contents of the TTY<n>"""
self.execute(f"fold -w 80 /dev/vcs{tty} | systemd-cat")
def _get_screen_text_variants(self, model_ids: Iterable[int]) -> list[str]:
    """
    Take a screenshot into a temporary directory and run OCR on it.

    `model_ids` is handed through to `_perform_ocr_on_screenshot`, whose
    result is returned. The temporary directory is removed afterwards.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        dump_path = os.path.join(scratch_dir, "ppm")
        self.send_monitor_command(f"screendump {dump_path}")
        ocr_results = _perform_ocr_on_screenshot(dump_path, model_ids)
    return ocr_results
def get_screen_text_variants(self) -> list[str]:
"""
Return a list of different interpretations of what is currently
@ -1021,7 +951,8 @@ class Machine:
This requires [`enableOCR`](#test-opt-enableOCR) to be set to `true`.
:::
"""
return self._get_screen_text_variants([0, 1, 2])
with self._managed_screenshot() as screenshot_path:
return perform_ocr_variants_on_screenshot(screenshot_path)
def get_screen_text(self) -> str:
"""
@ -1032,7 +963,8 @@ class Machine:
This requires [`enableOCR`](#test-opt-enableOCR) to be set to `true`.
:::
"""
return self._get_screen_text_variants([2])[0]
with self._managed_screenshot() as screenshot_path:
return perform_ocr_on_screenshot(screenshot_path)
def wait_for_text(self, regex: str, timeout: int = 900) -> None:
"""

View file

@ -0,0 +1,111 @@
import shutil
import subprocess
from test_driver.errors import MachineError
def perform_ocr_on_screenshot(screenshot_path: str) -> str:
    """
    Run OCR on a single screenshot and return the recognised text.

    Delegates to `perform_ocr_variants_on_screenshot` with `variants`
    disabled and returns the first entry of its result list.
    """
    ocr_results = perform_ocr_variants_on_screenshot(screenshot_path, variants=False)
    return ocr_results[0]
def perform_ocr_variants_on_screenshot(
    screenshot_path: str, variants: bool = True
) -> list[str]:
    """
    Perform OCR on a screenshot, optionally on several variants of it.

    With `variants` enabled (the default), tesseract is run on the screenshot
    itself, on a preprocessed copy and on a preprocessed negative of it, once
    for each of the OCR engine modes 0, 1 and 2. This can lead to more words
    being detected.

    With `variants` disabled, no image variants are created: only the
    unmodified screenshot is processed, using the default engine mode 3.

    Returns a list with one string of recognised text per (image, engine
    mode) combination, grouped by image with the unmodified screenshot first.

    Raises `MachineError` if `tesseract` is not available or exits with a
    non-zero code, or if preprocessing an image variant fails.
    """
    if shutil.which("tesseract") is None:
        raise MachineError("OCR requested but `tesseract` is not available")

    # tesseract --help-oem
    # OCR Engine modes (OEM):
    #  0|tesseract_only          Legacy engine only.
    #  1|lstm_only               Neural nets LSTM engine only.
    #  2|tesseract_lstm_combined Legacy + LSTM engines.
    #  3|default                 Default, based on what is available.
    model_ids: list[int] = [0, 1, 2] if variants else [3]

    # The unmodified screenshot always comes first so that callers can rely
    # on index 0. The ImageMagick variants are only produced when they were
    # asked for; otherwise plain OCR would needlessly depend on `magick` and
    # do three times the work.
    image_paths = [screenshot_path]
    if variants:
        image_paths += [
            _preprocess_screenshot(screenshot_path, negate=False),
            _preprocess_screenshot(screenshot_path, negate=True),
        ]

    def run_tesseract(image: str, model_id: int) -> str:
        # "-" makes tesseract print the recognised text to stdout.
        ret = subprocess.run(
            [
                "tesseract",
                image,
                "-",
                "--oem",
                str(model_id),
                "-c",
                "debug_file=/dev/null",
                "--psm",
                "11",
            ],
            capture_output=True,
        )
        if ret.returncode != 0:
            # Surface the captured stderr instead of silently dropping it.
            stderr = ret.stderr.decode("utf-8", errors="replace")
            raise MachineError(
                f"OCR failed with exit code {ret.returncode}, stderr: {stderr}"
            )
        return ret.stdout.decode("utf-8")

    return [
        run_tesseract(image, model_id)
        for image in image_paths
        for model_id in model_ids
    ]
def _preprocess_screenshot(screenshot_path: str, negate: bool = False) -> str:
    """
    Convert a screenshot with ImageMagick and return the path of the result.

    The converted image is written to `<screenshot_path>.png`. With `negate`
    set, the `-negate`, `-gamma` and `-blur` options are appended as well and
    the result is written to `<screenshot_path>.negative.png` instead.

    Raises `MachineError` if `magick` is not available or exits with a
    non-zero code.
    """
    if shutil.which("magick") is None:
        raise MachineError("OCR requested but `magick` is not available")

    command = [
        "magick",
        "convert",
        "-filter",
        "Catrom",
        "-density",
        "72",
        "-resample",
        "300",
        "-contrast",
        "-normalize",
        "-despeckle",
        "-type",
        "grayscale",
        "-sharpen",
        "1",
        "-posterize",
        "3",
    ]
    if negate:
        # The negated variant gets its own options and its own output name.
        command.extend(["-negate", "-gamma", "100", "-blur", "1x65535"])
        out_file = f"{screenshot_path}.negative.png"
    else:
        out_file = f"{screenshot_path}.png"
    # Input and output paths go last on the command line.
    command.extend([screenshot_path, out_file])

    ret = subprocess.run(command, capture_output=True)
    if ret.returncode == 0:
        return out_file
    raise MachineError(
        f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}"
    )