tree-wide: use named CUDA versions and CUDA version utilities

Signed-off-by: Connor Baker <ConnorBaker01@gmail.com>
This commit is contained in:
Connor Baker 2025-05-09 17:40:26 +00:00
parent a94100d091
commit bf766a2d97
27 changed files with 113 additions and 119 deletions

View file

@ -45,7 +45,7 @@ stdenv.mkDerivation (finalAttrs: {
shopt -s globstar shopt -s globstar
for cmakelists in **/CMakeLists.*; do for cmakelists in **/CMakeLists.*; do
sed -i "s/OpenSSL::OpenSSL/OpenSSL::SSL/g" $cmakelists sed -i "s/OpenSSL::OpenSSL/OpenSSL::SSL/g" $cmakelists
${lib.optionalString (lib.versionOlder cudaPackages.cudaVersion "11.8") '' ${lib.optionalString (cudaPackages.cudaOlder "11.8") ''
sed -i 's/-gencode=arch=compute_89,code=sm_89//g' $cmakelists sed -i 's/-gencode=arch=compute_89,code=sm_89//g' $cmakelists
sed -i 's/-gencode=arch=compute_90,code=sm_90//g' $cmakelists sed -i 's/-gencode=arch=compute_90,code=sm_90//g' $cmakelists
''} ''}

View file

@ -13,7 +13,7 @@
}: }:
let let
inherit (lib) lists strings; inherit (lib) lists strings;
inherit (cudaPackages) backendStdenv cudaVersion flags; inherit (cudaPackages) backendStdenv cudaAtLeast flags;
cuda-common-redist = with cudaPackages; [ cuda-common-redist = with cudaPackages; [
(lib.getDev cuda_cudart) # cuda_runtime.h (lib.getDev cuda_cudart) # cuda_runtime.h
@ -62,7 +62,7 @@ stdenv.mkDerivation (finalAttrs: {
# Remove this once a release is made with # Remove this once a release is made with
# https://github.com/NVlabs/tiny-cuda-nn/commit/78a14fe8c292a69f54e6d0d47a09f52b777127e1 # https://github.com/NVlabs/tiny-cuda-nn/commit/78a14fe8c292a69f54e6d0d47a09f52b777127e1
postPatch = lib.optionals (strings.versionAtLeast cudaVersion "11.0") '' postPatch = lib.optionals (cudaAtLeast "11.0") ''
substituteInPlace bindings/torch/setup.py --replace-fail \ substituteInPlace bindings/torch/setup.py --replace-fail \
"-std=c++14" "-std=c++17" "-std=c++14" "-std=c++17"
''; '';

View file

@ -1,14 +1,16 @@
# Packages which have been deprecated or removed from cudaPackages # Packages which have been deprecated or removed from cudaPackages
final: prev: final: _:
let let
inherit (prev.lib) warn;
inherit (builtins) mapAttrs;
mkRenamed = mkRenamed =
oldName: oldName:
{ path, package }: { path, package }:
warn "cudaPackages.${oldName} is deprecated, use ${path} instead" package; final.lib.warn "cudaPackages.${oldName} is deprecated, use ${path} instead" package;
in in
mapAttrs mkRenamed { builtins.mapAttrs mkRenamed {
# A comment to prevent empty { } from collapsing into a single line # A comment to prevent empty { } from collapsing into a single line
cudaVersion = {
path = "cudaPackages.cudaMajorMinorVersion";
package = final.cudaMajorMinorVersion;
};
} }

View file

@ -1,5 +1,5 @@
{ {
cudaVersion, cudaMajorMinorVersion,
lib, lib,
stdenv, stdenv,
}: }:
@ -27,7 +27,7 @@ let
# Samples are built around the CUDA Toolkit, which is not available for # Samples are built around the CUDA Toolkit, which is not available for
# aarch64. Check for both CUDA version and platform. # aarch64. Check for both CUDA version and platform.
cudaVersionIsSupported = cudaVersionToHash ? ${cudaVersion}; cudaVersionIsSupported = cudaVersionToHash ? ${cudaMajorMinorVersion};
platformIsSupported = hostPlatform.isx86_64; platformIsSupported = hostPlatform.isx86_64;
isSupported = cudaVersionIsSupported && platformIsSupported; isSupported = cudaVersionIsSupported && platformIsSupported;
@ -36,8 +36,7 @@ let
final: _: final: _:
lib.attrsets.optionalAttrs isSupported { lib.attrsets.optionalAttrs isSupported {
cuda-samples = final.callPackage ./generic.nix { cuda-samples = final.callPackage ./generic.nix {
inherit cudaVersion; hash = cudaVersionToHash.${cudaMajorMinorVersion};
hash = cudaVersionToHash.${cudaVersion};
}; };
}; };
in in

View file

@ -3,7 +3,7 @@
backendStdenv, backendStdenv,
cmake, cmake,
cudatoolkit, cudatoolkit,
cudaVersion, cudaMajorMinorVersion,
fetchFromGitHub, fetchFromGitHub,
fetchpatch, fetchpatch,
freeimage, freeimage,
@ -20,7 +20,7 @@ backendStdenv.mkDerivation (finalAttrs: {
strictDeps = true; strictDeps = true;
pname = "cuda-samples"; pname = "cuda-samples";
version = cudaVersion; version = cudaMajorMinorVersion;
src = fetchFromGitHub { src = fetchFromGitHub {
owner = "NVIDIA"; owner = "NVIDIA";

View file

@ -1,4 +1,4 @@
{ cudaVersion, lib }: { cudaMajorMinorVersion, lib }:
let let
inherit (lib) attrsets modules trivial; inherit (lib) attrsets modules trivial;
redistName = "cuda"; redistName = "cuda";
@ -23,10 +23,10 @@ let
}; };
# Check if the current CUDA version is supported. # Check if the current CUDA version is supported.
cudaVersionMappingExists = builtins.hasAttr cudaVersion cudaVersionMap; cudaVersionMappingExists = builtins.hasAttr cudaMajorMinorVersion cudaVersionMap;
# fullCudaVersion : String # fullCudaVersion : String
fullCudaVersion = cudaVersionMap.${cudaVersion}; fullCudaVersion = cudaVersionMap.${cudaMajorMinorVersion};
evaluatedModules = modules.evalModules { evaluatedModules = modules.evalModules {
modules = [ modules = [

View file

@ -1,5 +1,5 @@
{ {
cudaVersion, cudaMajorMinorVersion,
runPatches ? [ ], runPatches ? [ ],
autoPatchelfHook, autoPatchelfHook,
autoAddDriverRunpath, autoAddDriverRunpath,
@ -54,7 +54,7 @@
let let
# Version info for the classic cudatoolkit packages that contain everything that is in redist. # Version info for the classic cudatoolkit packages that contain everything that is in redist.
releases = builtins.import ./releases.nix; releases = builtins.import ./releases.nix;
release = releases.${cudaVersion}; release = releases.${cudaMajorMinorVersion};
in in
backendStdenv.mkDerivation rec { backendStdenv.mkDerivation rec {

View file

@ -4,7 +4,7 @@
backendStdenv, backendStdenv,
cudaOlder, cudaOlder,
cudatoolkit-legacy-runfile, cudatoolkit-legacy-runfile,
cudaVersion, cudaMajorMinorVersion,
cuda_cccl ? null, cuda_cccl ? null,
cuda_cudart ? null, cuda_cudart ? null,
cuda_cuobjdump ? null, cuda_cuobjdump ? null,
@ -66,8 +66,8 @@ if cudaOlder "11.4" then
cudatoolkit-legacy-runfile cudatoolkit-legacy-runfile
else else
symlinkJoin rec { symlinkJoin rec {
name = "cuda-merged-${cudaVersion}"; name = "cuda-merged-${cudaMajorMinorVersion}";
version = cudaVersion; version = cudaMajorMinorVersion;
paths = builtins.concatMap getAllOutputs allPackages; paths = builtins.concatMap getAllOutputs allPackages;

View file

@ -3,7 +3,7 @@
{ {
lib, lib,
stdenv, stdenv,
cudaVersion, cudaMajorMinorVersion,
flags, flags,
mkVersionedPackageName, mkVersionedPackageName,
}: }:
@ -54,15 +54,6 @@ let
releaseGrabber releaseGrabber
]) cusparseltVersions; ]) cusparseltVersions;
# Our cudaVersion tells us which version of CUDA we're building against.
# The subdirectories in lib/ tell us which versions of CUDA are supported.
# Typically the names will look like this:
#
# - 10.2
# - 11
# - 11.0
# - 12
# A release is supported if it has a libPath that matches our CUDA version for our platform. # A release is supported if it has a libPath that matches our CUDA version for our platform.
# LibPath are not constant across the same release -- one platform may support fewer # LibPath are not constant across the same release -- one platform may support fewer
# CUDA versions than another. # CUDA versions than another.

View file

@ -13,7 +13,7 @@
# - Instead of providing different releases for each version of CUDA, CuTensor has multiple subdirectories in `lib` # - Instead of providing different releases for each version of CUDA, CuTensor has multiple subdirectories in `lib`
# -- one for each version of CUDA. # -- one for each version of CUDA.
{ {
cudaVersion, cudaMajorMinorVersion,
flags, flags,
lib, lib,
mkVersionedPackageName, mkVersionedPackageName,
@ -73,7 +73,7 @@ let
releaseGrabber releaseGrabber
]) cutensorVersions; ]) cutensorVersions;
# Our cudaVersion tells us which version of CUDA we're building against. # Our cudaMajorMinorVersion tells us which version of CUDA we're building against.
# The subdirectories in lib/ tell us which versions of CUDA are supported. # The subdirectories in lib/ tell us which versions of CUDA are supported.
# Typically the names will look like this: # Typically the names will look like this:
# #
@ -85,10 +85,9 @@ let
# libPath :: String # libPath :: String
libPath = libPath =
let let
cudaMajorMinor = versions.majorMinor cudaVersion; cudaMajorVersion = versions.major cudaMajorMinorVersion;
cudaMajor = versions.major cudaVersion;
in in
if cudaMajorMinor == "10.2" then cudaMajorMinor else cudaMajor; if cudaMajorMinorVersion == "10.2" then cudaMajorMinorVersion else cudaMajorVersion;
# A release is supported if it has a libPath that matches our CUDA version for our platform. # A release is supported if it has a libPath that matches our CUDA version for our platform.
# LibPath are not constant across the same release -- one platform may support fewer # LibPath are not constant across the same release -- one platform may support fewer

View file

@ -1,7 +1,7 @@
{ {
cudaOlder, cudaOlder,
cudatoolkit, cudatoolkit,
cudaVersion, cudaMajorMinorVersion,
fetchurl, fetchurl,
lib, lib,
libcublas ? null, # cuDNN uses CUDA Toolkit on old releases, where libcublas is not available. libcublas ? null, # cuDNN uses CUDA Toolkit on old releases, where libcublas is not available.
@ -25,7 +25,7 @@ finalAttrs: prevAttrs: {
cudaTooOld = cudaOlder finalAttrs.passthru.featureRelease.minCudaVersion; cudaTooOld = cudaOlder finalAttrs.passthru.featureRelease.minCudaVersion;
cudaTooNew = cudaTooNew =
(finalAttrs.passthru.featureRelease.maxCudaVersion != null) (finalAttrs.passthru.featureRelease.maxCudaVersion != null)
&& strings.versionOlder finalAttrs.passthru.featureRelease.maxCudaVersion cudaVersion; && strings.versionOlder finalAttrs.passthru.featureRelease.maxCudaVersion cudaMajorMinorVersion;
in in
prevAttrs.badPlatformsConditions or { } prevAttrs.badPlatformsConditions or { }
// { // {

View file

@ -1,7 +1,7 @@
{ {
cudaOlder, cudaOlder,
cudaPackages, cudaPackages,
cudaVersion, cudaMajorMinorVersion,
lib, lib,
mkVersionedPackageName, mkVersionedPackageName,
patchelf, patchelf,
@ -30,7 +30,7 @@ finalAttrs: prevAttrs: {
cudaTooOld = cudaOlder finalAttrs.passthru.featureRelease.minCudaVersion; cudaTooOld = cudaOlder finalAttrs.passthru.featureRelease.minCudaVersion;
cudaTooNew = cudaTooNew =
(finalAttrs.passthru.featureRelease.maxCudaVersion != null) (finalAttrs.passthru.featureRelease.maxCudaVersion != null)
&& strings.versionOlder finalAttrs.passthru.featureRelease.maxCudaVersion cudaVersion; && strings.versionOlder finalAttrs.passthru.featureRelease.maxCudaVersion cudaMajorMinorVersion;
cudnnVersionIsSpecified = finalAttrs.passthru.featureRelease.cudnnVersion != null; cudnnVersionIsSpecified = finalAttrs.passthru.featureRelease.cudnnVersion != null;
cudnnVersionSpecified = versions.majorMinor finalAttrs.passthru.featureRelease.cudnnVersion; cudnnVersionSpecified = versions.majorMinor finalAttrs.passthru.featureRelease.cudnnVersion;
cudnnVersionProvided = versions.majorMinor finalAttrs.passthru.cudnn.version; cudnnVersionProvided = versions.majorMinor finalAttrs.passthru.cudnn.version;
@ -52,7 +52,7 @@ finalAttrs: prevAttrs: {
inherit (finalAttrs.passthru.redistribRelease) hash; inherit (finalAttrs.passthru.redistribRelease) hash;
message = '' message = ''
To use the TensorRT derivation, you must join the NVIDIA Developer Program and To use the TensorRT derivation, you must join the NVIDIA Developer Program and
download the ${finalAttrs.version} TAR package for CUDA ${cudaVersion} from download the ${finalAttrs.version} TAR package for CUDA ${cudaMajorMinorVersion} from
${finalAttrs.meta.homepage}. ${finalAttrs.meta.homepage}.
Once you have downloaded the file, add it to the store with the following Once you have downloaded the file, add it to the store with the following
@ -96,7 +96,7 @@ finalAttrs: prevAttrs: {
''; '';
passthru = prevAttrs.passthru or { } // { passthru = prevAttrs.passthru or { } // {
useCudatoolkitRunfile = strings.versionOlder cudaVersion "11.3.999"; useCudatoolkitRunfile = strings.versionOlder cudaMajorMinorVersion "11.3.999";
# The CUDNN used with TensorRT. # The CUDNN used with TensorRT.
# If null, the default cudnn derivation will be used. # If null, the default cudnn derivation will be used.
# If a version is specified, the cudnn derivation with that version will be used, # If a version is specified, the cudnn derivation with that version will be used,

View file

@ -6,7 +6,7 @@
cudaCapabilities ? (config.cudaCapabilities or [ ]), cudaCapabilities ? (config.cudaCapabilities or [ ]),
cudaForwardCompat ? (config.cudaForwardCompat or true), cudaForwardCompat ? (config.cudaForwardCompat or true),
lib, lib,
cudaVersion, cudaMajorMinorVersion,
stdenv, stdenv,
# gpus :: List Gpu # gpus :: List Gpu
gpus, gpus,
@ -44,9 +44,9 @@ let
gpu: gpu:
let let
inherit (gpu) minCudaVersion maxCudaVersion; inherit (gpu) minCudaVersion maxCudaVersion;
lowerBoundSatisfied = strings.versionAtLeast cudaVersion minCudaVersion; lowerBoundSatisfied = strings.versionAtLeast cudaMajorMinorVersion minCudaVersion;
upperBoundSatisfied = upperBoundSatisfied =
(maxCudaVersion == null) || !(strings.versionOlder maxCudaVersion cudaVersion); (maxCudaVersion == null) || !(strings.versionOlder maxCudaVersion cudaMajorMinorVersion);
in in
lowerBoundSatisfied && upperBoundSatisfied; lowerBoundSatisfied && upperBoundSatisfied;
@ -57,7 +57,7 @@ let
let let
inherit (gpu) dontDefaultAfter isJetson; inherit (gpu) dontDefaultAfter isJetson;
newGpu = dontDefaultAfter == null; newGpu = dontDefaultAfter == null;
recentGpu = newGpu || strings.versionAtLeast dontDefaultAfter cudaVersion; recentGpu = newGpu || strings.versionAtLeast dontDefaultAfter cudaMajorMinorVersion;
in in
recentGpu && !isJetson; recentGpu && !isJetson;
@ -289,12 +289,14 @@ assert
}; };
actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value; actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
in in
asserts.assertMsg ((strings.versionAtLeast cudaVersion "11.2") -> (expected == actualWrapped)) '' asserts.assertMsg
This test should only fail when using a version of CUDA older than 11.2, the first to support ((strings.versionAtLeast cudaMajorMinorVersion "11.2") -> (expected == actualWrapped))
8.6. ''
Expected: ${builtins.toJSON expected} This test should only fail when using a version of CUDA older than 11.2, the first to support
Actual: ${builtins.toJSON actualWrapped} 8.6.
''; Expected: ${builtins.toJSON expected}
Actual: ${builtins.toJSON actualWrapped}
'';
# Check mixed Jetson and non-Jetson devices # Check mixed Jetson and non-Jetson devices
assert assert
let let

View file

@ -1,7 +1,7 @@
{ {
# callPackage-provided arguments # callPackage-provided arguments
lib, lib,
cudaVersion, cudaMajorMinorVersion,
flags, flags,
stdenv, stdenv,
# Expected to be passed by the caller # Expected to be passed by the caller
@ -69,8 +69,8 @@ let
isSupported = isSupported =
package: package:
redistArch == package.redistArch redistArch == package.redistArch
&& strings.versionAtLeast cudaVersion package.minCudaVersion && strings.versionAtLeast cudaMajorMinorVersion package.minCudaVersion
&& strings.versionAtLeast package.maxCudaVersion cudaVersion; && strings.versionAtLeast package.maxCudaVersion cudaMajorMinorVersion;
# Get all of the packages for our given platform. # Get all of the packages for our given platform.
# redistArch :: String # redistArch :: String

View file

@ -7,7 +7,7 @@
# E.g. for cudaPackages_11_8 we use gcc11 with gcc12's libstdc++ # E.g. for cudaPackages_11_8 we use gcc11 with gcc12's libstdc++
# Cf. https://github.com/NixOS/nixpkgs/pull/218265 for context # Cf. https://github.com/NixOS/nixpkgs/pull/218265 for context
{ {
cudaVersion, cudaMajorMinorVersion,
lib, lib,
nvccCompatibilities, nvccCompatibilities,
pkgs, pkgs,
@ -16,7 +16,7 @@
}: }:
let let
gccMajorVersion = nvccCompatibilities.${cudaVersion}.gccMaxMajorVersion; gccMajorVersion = nvccCompatibilities.${cudaMajorMinorVersion}.gccMaxMajorVersion;
cudaStdenv = stdenvAdapters.useLibsFrom stdenv pkgs."gcc${gccMajorVersion}Stdenv"; cudaStdenv = stdenvAdapters.useLibsFrom stdenv pkgs."gcc${gccMajorVersion}Stdenv";
passthruExtra = { passthruExtra = {
# cudaPackages.backendStdenv.nixpkgsCompatibleLibstdcxx has been removed, # cudaPackages.backendStdenv.nixpkgsCompatibleLibstdcxx has been removed,

View file

@ -17,8 +17,9 @@ let
cuda_cccl cuda_cccl
cuda_cudart cuda_cudart
cuda_nvcc cuda_nvcc
cudaAtLeast
cudaOlder
cudatoolkit cudatoolkit
cudaVersion
nccl nccl
; ;
in in
@ -43,25 +44,25 @@ backendStdenv.mkDerivation (finalAttrs: {
nativeBuildInputs = nativeBuildInputs =
[ which ] [ which ]
++ lib.optionals (lib.versionOlder cudaVersion "11.4") [ cudatoolkit ] ++ lib.optionals (cudaOlder "11.4") [ cudatoolkit ]
++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [ cuda_nvcc ]; ++ lib.optionals (cudaAtLeast "11.4") [ cuda_nvcc ];
buildInputs = buildInputs =
[ nccl ] [ nccl ]
++ lib.optionals (lib.versionOlder cudaVersion "11.4") [ cudatoolkit ] ++ lib.optionals (cudaOlder "11.4") [ cudatoolkit ]
++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [ ++ lib.optionals (cudaAtLeast "11.4") [
cuda_nvcc # crt/host_config.h cuda_nvcc # crt/host_config.h
cuda_cudart cuda_cudart
] ]
++ lib.optionals (lib.versionAtLeast cudaVersion "12.0") [ ++ lib.optionals (cudaAtLeast "12.0") [
cuda_cccl # <nv/target> cuda_cccl # <nv/target>
] ]
++ lib.optionals mpiSupport [ mpi ]; ++ lib.optionals mpiSupport [ mpi ];
makeFlags = makeFlags =
[ "NCCL_HOME=${nccl}" ] [ "NCCL_HOME=${nccl}" ]
++ lib.optionals (lib.versionOlder cudaVersion "11.4") [ "CUDA_HOME=${cudatoolkit}" ] ++ lib.optionals (cudaOlder "11.4") [ "CUDA_HOME=${cudatoolkit}" ]
++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [ "CUDA_HOME=${cuda_nvcc}" ] ++ lib.optionals (cudaAtLeast "11.4") [ "CUDA_HOME=${cuda_nvcc}" ]
++ lib.optionals mpiSupport [ "MPI=1" ]; ++ lib.optionals mpiSupport [ "MPI=1" ];
enableParallelBuilding = true; enableParallelBuilding = true;

View file

@ -98,6 +98,6 @@ stdenv.mkDerivation rec {
hexa hexa
misuzu misuzu
]; ];
broken = (lib.versionOlder cudaPackages.cudaVersion "11.4") || !(withCuDNN -> withCUDA); broken = (cudaPackages.cudaOlder "11.4") || !(withCuDNN -> withCUDA);
}; };
} }

View file

@ -14,9 +14,9 @@ let
version = "0.45.1"; version = "0.45.1";
inherit (torch) cudaPackages cudaSupport; inherit (torch) cudaPackages cudaSupport;
inherit (cudaPackages) cudaVersion; inherit (cudaPackages) cudaMajorMinorVersion;
cudaVersionString = lib.replaceStrings [ "." ] [ "" ] (lib.versions.majorMinor cudaVersion); cudaMajorMinorVersionString = lib.replaceStrings [ "." ] [ "" ] cudaMajorMinorVersion;
# NOTE: torchvision doesn't use cudnn; torch does! # NOTE: torchvision doesn't use cudnn; torch does!
# For this reason it is not included. # For this reason it is not included.
@ -32,7 +32,7 @@ let
]; ];
cuda-native-redist = symlinkJoin { cuda-native-redist = symlinkJoin {
name = "cuda-native-redist-${cudaVersion}"; name = "cuda-native-redist-${cudaMajorMinorVersion}";
paths = paths =
with cudaPackages; with cudaPackages;
[ [
@ -45,7 +45,7 @@ let
}; };
cuda-redist = symlinkJoin { cuda-redist = symlinkJoin {
name = "cuda-redist-${cudaVersion}"; name = "cuda-redist-${cudaMajorMinorVersion}";
paths = cuda-common-redist; paths = cuda-common-redist;
}; };
in in
@ -73,7 +73,7 @@ buildPythonPackage {
--replace-fail "if cuda_specs:" "if True:" \ --replace-fail "if cuda_specs:" "if True:" \
--replace-fail \ --replace-fail \
"cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \ "cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
"cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_cuda${cudaVersionString}.so'" "cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_cuda${cudaMajorMinorVersionString}.so'"
''; '';
nativeBuildInputs = [ nativeBuildInputs = [

View file

@ -45,7 +45,7 @@ let
# cusparselt # cusparselt
]; ];
cudatoolkit-joined = symlinkJoin { cudatoolkit-joined = symlinkJoin {
name = "cudatoolkit-joined-${cudaPackages.cudaVersion}"; name = "cudatoolkit-joined-${cudaPackages.cudaMajorMinorVersion}";
paths = paths =
outpaths outpaths
++ lib.concatMap (f: lib.map f outpaths) [ ++ lib.concatMap (f: lib.map f outpaths) [

View file

@ -13,7 +13,7 @@
}: }:
let let
inherit (jaxlib) version; inherit (jaxlib) version;
inherit (cudaPackages) cudaVersion; inherit (cudaPackages) cudaAtLeast;
cudaLibPath = lib.makeLibraryPath ( cudaLibPath = lib.makeLibraryPath (
with cudaPackages; with cudaPackages;
@ -101,7 +101,6 @@ buildPythonPackage rec {
platforms = lib.platforms.linux; platforms = lib.platforms.linux;
# see CUDA compatibility matrix # see CUDA compatibility matrix
# https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
broken = broken = !(cudaAtLeast "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1");
!(lib.versionAtLeast cudaVersion "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1");
}; };
} }

View file

@ -13,7 +13,7 @@
}: }:
let let
inherit (jaxlib) version; inherit (jaxlib) version;
inherit (cudaPackages) cudaVersion; inherit (cudaPackages) cudaAtLeast;
inherit (jax-cuda12-pjrt) cudaLibPath; inherit (jax-cuda12-pjrt) cudaLibPath;
getSrcFromPypi = getSrcFromPypi =
@ -133,7 +133,6 @@ buildPythonPackage {
platforms = lib.platforms.linux; platforms = lib.platforms.linux;
# see CUDA compatibility matrix # see CUDA compatibility matrix
# https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
broken = broken = !(cudaAtLeast "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1");
!(lib.versionAtLeast cudaVersion "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1");
}; };
} }

View file

@ -53,7 +53,7 @@
let let
inherit (cudaPackages) inherit (cudaPackages)
cudaFlags cudaFlags
cudaVersion cudaMajorMinorVersion
nccl nccl
; ;
@ -317,7 +317,7 @@ let
build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}" build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
build --action_env CUDNN_INSTALL_PATH="${cudnnMerged}" build --action_env CUDNN_INSTALL_PATH="${cudnnMerged}"
build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnnMerged},${lib.getDev nccl}" build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnnMerged},${lib.getDev nccl}"
build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudaVersion}" build --action_env TF_CUDA_VERSION="${cudaMajorMinorVersion}"
build --action_env TF_CUDNN_VERSION="${lib.versions.major cudaPackages.cudnn.version}" build --action_env TF_CUDNN_VERSION="${lib.versions.major cudaPackages.cudnn.version}"
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}" build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
'' ''

View file

@ -149,7 +149,7 @@ let
]; ];
cudatoolkitDevMerged = symlinkJoin { cudatoolkitDevMerged = symlinkJoin {
name = "cuda-${cudaPackages.cudaVersion}-dev-merged"; name = "cuda-${cudaPackages.cudaMajorMinorVersion}-dev-merged";
paths = lib.concatMap (p: [ paths = lib.concatMap (p: [
(lib.getBin p) (lib.getBin p)
(lib.getDev p) (lib.getDev p)

View file

@ -232,7 +232,8 @@ let
# effectiveMagma.cudaPackages, making torch too strict in cudaPackages. # effectiveMagma.cudaPackages, making torch too strict in cudaPackages.
# In particular, this triggered warnings from cuda's `aliases.nix` # In particular, this triggered warnings from cuda's `aliases.nix`
"Magma cudaPackages does not match cudaPackages" = "Magma cudaPackages does not match cudaPackages" =
cudaSupport && (effectiveMagma.cudaPackages.cudaVersion != cudaPackages.cudaVersion); cudaSupport
&& (effectiveMagma.cudaPackages.cudaMajorMinorVersion != cudaPackages.cudaMajorMinorVersion);
}; };
unroll-src = writeShellScript "unroll-src" '' unroll-src = writeShellScript "unroll-src" ''
@ -414,12 +415,12 @@ buildPythonPackage rec {
[ [
(lib.cmakeFeature "PYTHON_SIX_SOURCE_DIR" "${six.src}") (lib.cmakeFeature "PYTHON_SIX_SOURCE_DIR" "${six.src}")
# (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true) # (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true)
(lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaVersion) (lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaMajorMinorVersion)
] ]
++ lib.optionals cudaSupport [ ++ lib.optionals cudaSupport [
# Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache # Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache
# Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363 # Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363
(lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaVersion) (lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaMajorMinorVersion)
]; ];
preBuild = '' preBuild = ''
@ -529,10 +530,10 @@ buildPythonPackage rec {
# Some platforms do not support NCCL (i.e., Jetson) # Some platforms do not support NCCL (i.e., Jetson)
nccl # Provides nccl.h AND a static copy of NCCL! nccl # Provides nccl.h AND a static copy of NCCL!
] ]
++ lists.optionals (strings.versionOlder cudaVersion "11.8") [ ++ lists.optionals (cudaOlder "11.8") [
cuda_nvprof # <cuda_profiler_api.h> cuda_nvprof # <cuda_profiler_api.h>
] ]
++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [ ++ lists.optionals (cudaAtLeast "11.8") [
cuda_profiler_api # <cuda_profiler_api.h> cuda_profiler_api # <cuda_profiler_api.h>
] ]
) )

View file

@ -379,7 +379,7 @@ buildPythonPackage rec {
(lib.cmakeFeature "CUTLASS_NVCC_ARCHS_ENABLED" "${cudaPackages.cudaFlags.cmakeCudaArchitecturesString (lib.cmakeFeature "CUTLASS_NVCC_ARCHS_ENABLED" "${cudaPackages.cudaFlags.cmakeCudaArchitecturesString
}") }")
(lib.cmakeFeature "CUDA_TOOLKIT_ROOT_DIR" "${symlinkJoin { (lib.cmakeFeature "CUDA_TOOLKIT_ROOT_DIR" "${symlinkJoin {
name = "cuda-merged-${cudaPackages.cudaVersion}"; name = "cuda-merged-${cudaPackages.cudaMajorMinorVersion}";
paths = builtins.concatMap getAllOutputs mergedCudaLibraries; paths = builtins.concatMap getAllOutputs mergedCudaLibraries;
}}") }}")
(lib.cmakeFeature "CAFFE2_USE_CUDNN" "ON") (lib.cmakeFeature "CAFFE2_USE_CUDNN" "ON")

View file

@ -2745,25 +2745,25 @@ with pkgs;
cron = isc-cron; cron = isc-cron;
cudaPackages_11_0 = callPackage ./cuda-packages.nix { cudaVersion = "11.0"; }; cudaPackages_11_0 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "11.0"; };
cudaPackages_11_1 = callPackage ./cuda-packages.nix { cudaVersion = "11.1"; }; cudaPackages_11_1 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "11.1"; };
cudaPackages_11_2 = callPackage ./cuda-packages.nix { cudaVersion = "11.2"; }; cudaPackages_11_2 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "11.2"; };
cudaPackages_11_3 = callPackage ./cuda-packages.nix { cudaVersion = "11.3"; }; cudaPackages_11_3 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "11.3"; };
cudaPackages_11_4 = callPackage ./cuda-packages.nix { cudaVersion = "11.4"; }; cudaPackages_11_4 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "11.4"; };
cudaPackages_11_5 = callPackage ./cuda-packages.nix { cudaVersion = "11.5"; }; cudaPackages_11_5 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "11.5"; };
cudaPackages_11_6 = callPackage ./cuda-packages.nix { cudaVersion = "11.6"; }; cudaPackages_11_6 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "11.6"; };
cudaPackages_11_7 = callPackage ./cuda-packages.nix { cudaVersion = "11.7"; }; cudaPackages_11_7 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "11.7"; };
cudaPackages_11_8 = callPackage ./cuda-packages.nix { cudaVersion = "11.8"; }; cudaPackages_11_8 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "11.8"; };
cudaPackages_11 = recurseIntoAttrs cudaPackages_11_8; cudaPackages_11 = recurseIntoAttrs cudaPackages_11_8;
cudaPackages_12_0 = callPackage ./cuda-packages.nix { cudaVersion = "12.0"; }; cudaPackages_12_0 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "12.0"; };
cudaPackages_12_1 = callPackage ./cuda-packages.nix { cudaVersion = "12.1"; }; cudaPackages_12_1 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "12.1"; };
cudaPackages_12_2 = callPackage ./cuda-packages.nix { cudaVersion = "12.2"; }; cudaPackages_12_2 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "12.2"; };
cudaPackages_12_3 = callPackage ./cuda-packages.nix { cudaVersion = "12.3"; }; cudaPackages_12_3 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "12.3"; };
cudaPackages_12_4 = callPackage ./cuda-packages.nix { cudaVersion = "12.4"; }; cudaPackages_12_4 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "12.4"; };
cudaPackages_12_5 = callPackage ./cuda-packages.nix { cudaVersion = "12.5"; }; cudaPackages_12_5 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "12.5"; };
cudaPackages_12_6 = callPackage ./cuda-packages.nix { cudaVersion = "12.6"; }; cudaPackages_12_6 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "12.6"; };
cudaPackages_12_8 = callPackage ./cuda-packages.nix { cudaVersion = "12.8"; }; cudaPackages_12_8 = callPackage ./cuda-packages.nix { cudaMajorMinorVersion = "12.8"; };
cudaPackages_12 = cudaPackages_12_8; # Latest supported by cudnn cudaPackages_12 = cudaPackages_12_8; # Latest supported by cudnn
cudaPackages = recurseIntoAttrs cudaPackages_12; cudaPackages = recurseIntoAttrs cudaPackages_12;

View file

@ -22,7 +22,7 @@
# I've (@connorbaker) attempted to do that, though I'm unsure of how this will interact with overrides. # I've (@connorbaker) attempted to do that, though I'm unsure of how this will interact with overrides.
{ {
config, config,
cudaVersion, cudaMajorMinorVersion,
lib, lib,
newScope, newScope,
pkgs, pkgs,
@ -44,7 +44,7 @@ let
flags = import ../development/cuda-modules/flags.nix { flags = import ../development/cuda-modules/flags.nix {
inherit inherit
config config
cudaVersion cudaMajorMinorVersion
gpus gpus
lib lib
stdenv stdenv
@ -56,7 +56,7 @@ let
passthruFunction = final: { passthruFunction = final: {
inherit inherit
cudaVersion cudaMajorMinorVersion
fixups fixups
flags flags
gpus gpus
@ -64,10 +64,9 @@ let
nvccCompatibilities nvccCompatibilities
pkgs pkgs
; ;
cudaMajorVersion = versions.major cudaVersion; cudaMajorVersion = versions.major cudaMajorMinorVersion;
cudaMajorMinorVersion = versions.majorMinor cudaVersion; cudaOlder = strings.versionOlder cudaMajorMinorVersion;
cudaOlder = strings.versionOlder cudaVersion; cudaAtLeast = strings.versionAtLeast cudaMajorMinorVersion;
cudaAtLeast = strings.versionAtLeast cudaVersion;
# NOTE: mkVersionedPackageName is an internal, implementation detail and should not be relied on by outside consumers. # NOTE: mkVersionedPackageName is an internal, implementation detail and should not be relied on by outside consumers.
# It may be removed in the future. # It may be removed in the future.
@ -144,10 +143,10 @@ let
directory = ../development/cuda-modules/packages; directory = ../development/cuda-modules/packages;
} }
) )
(import ../development/cuda-modules/cuda/extension.nix { inherit cudaVersion lib; }) (import ../development/cuda-modules/cuda/extension.nix { inherit cudaMajorMinorVersion lib; })
(import ../development/cuda-modules/generic-builders/multiplex.nix { (import ../development/cuda-modules/generic-builders/multiplex.nix {
inherit inherit
cudaVersion cudaMajorMinorVersion
flags flags
lib lib
mkVersionedPackageName mkVersionedPackageName
@ -160,7 +159,7 @@ let
}) })
(import ../development/cuda-modules/cutensor/extension.nix { (import ../development/cuda-modules/cutensor/extension.nix {
inherit inherit
cudaVersion cudaMajorMinorVersion
flags flags
lib lib
mkVersionedPackageName mkVersionedPackageName
@ -169,7 +168,7 @@ let
}) })
(import ../development/cuda-modules/cusparselt/extension.nix { (import ../development/cuda-modules/cusparselt/extension.nix {
inherit inherit
cudaVersion cudaMajorMinorVersion
flags flags
lib lib
mkVersionedPackageName mkVersionedPackageName
@ -178,7 +177,7 @@ let
}) })
(import ../development/cuda-modules/generic-builders/multiplex.nix { (import ../development/cuda-modules/generic-builders/multiplex.nix {
inherit inherit
cudaVersion cudaMajorMinorVersion
flags flags
lib lib
mkVersionedPackageName mkVersionedPackageName
@ -189,7 +188,9 @@ let
releasesModule = ../development/cuda-modules/tensorrt/releases.nix; releasesModule = ../development/cuda-modules/tensorrt/releases.nix;
shimsFn = ../development/cuda-modules/tensorrt/shims.nix; shimsFn = ../development/cuda-modules/tensorrt/shims.nix;
}) })
(import ../development/cuda-modules/cuda-samples/extension.nix { inherit cudaVersion lib stdenv; }) (import ../development/cuda-modules/cuda-samples/extension.nix {
inherit cudaMajorMinorVersion lib stdenv;
})
(import ../development/cuda-modules/cuda-library-samples/extension.nix { inherit lib stdenv; }) (import ../development/cuda-modules/cuda-library-samples/extension.nix { inherit lib stdenv; })
] ]
++ lib.optionals config.allowAliases [ (import ../development/cuda-modules/aliases.nix) ] ++ lib.optionals config.allowAliases [ (import ../development/cuda-modules/aliases.nix) ]