diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index 6e70b24f67da..3504c6bf3204 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -97,7 +97,8 @@ buildPythonPackage rec { done ''; - # pip dependencies and optionally cudatoolkit. + # pip dependencies and optionally cudatoolkit. Note that cudatoolkit is + # necessary since jaxlib looks for "ptxas" in $PATH. propagatedBuildInputs = [ absl-py flatbuffers scipy ] ++ lib.optional cudaSupport cudatoolkit_11; pythonImportsCheck = [ "jaxlib" ];