nixpkgs/pkgs/development/python-modules/torch/gpu-checks.nix
2024-06-26 00:35:45 +00:00

50 lines
1 KiB
Nix

# GPU smoke tests for the CUDA- and ROCm-enabled torch builds.
#
# Each attribute is a derivation whose build imports torch and asserts that an
# accelerator is visible. Builds only get scheduled on builders that advertise
# the matching system feature (see `requiredSystemFeatures` below).
{
  # NOTE(review): `lib` is not used in this file; kept so the argument set is unchanged.
  lib,
  callPackage,
  torchWithCuda,
  torchWithRocm,
}:
let
  # Build a check derivation for one accelerator backend.
  #
  # Arguments:
  #   feature     - required Nix system feature ("cuda" / "rocm"); also part of
  #                 the derivation name.
  #   versionAttr - attribute of `torch.version` that must be truthy for this
  #                 backend ("cuda" for CUDA, "hip" for ROCm).
  #   torch       - the torch package under test.
  #   runCommand, writers - builders supplied by callPackage.
  #
  # The Python checker script itself is exposed as `passthru.unwrapped`.
  accelAvailable =
    {
      feature,
      versionAttr,
      torch,
      runCommand,
      writers,
    }:
    let
      name = "${torch.name}-${feature}-check";
      # ROCm builds are probed through torch.cuda too: PyTorch's HIP backend
      # reuses the torch.cuda interface, and only torch.version.<attr> differs.
      unwrapped = writers.writePython3Bin "${name}-unwrapped" { libraries = [ torch ]; } ''
        import torch
        message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}"
        assert torch.cuda.is_available() and torch.version.${versionAttr}, message
        print(message)
      '';
    in
    # runCommand already uses stdenvNoCC; runCommandNoCC is only a legacy alias for it.
    runCommand name
      {
        nativeBuildInputs = [ unwrapped ];
        # Restrict the build to machines that declare the GPU feature; the
        # assertion above would fail on a builder without the accelerator.
        requiredSystemFeatures = [ feature ];
        passthru = {
          inherit unwrapped;
        };
      }
      ''
        ${name}-unwrapped
        touch $out
      '';
in
{
  cudaAvailable = callPackage accelAvailable {
    feature = "cuda";
    versionAttr = "cuda";
    torch = torchWithCuda;
  };
  rocmAvailable = callPackage accelAvailable {
    feature = "rocm";
    versionAttr = "hip";
    torch = torchWithRocm;
  };
}