# Nix expression for the `jax` Python package, built from the upstream GitHub
# release tag (rather than PyPI) so that the test suite is available to run.
{ buildPythonPackage
, fetchFromGitHub
, lib
  # propagatedBuildInputs
, absl-py
, numpy
, opt-einsum
  # checkInputs
, jaxlib
, pytestCheckHook
}:

buildPythonPackage rec {
  pname = "jax";
  version = "0.2.21";

  # Fetching from pypi doesn't allow us to run the test suite. See
  # https://discourse.nixos.org/t/pythonremovetestsdir-hook-being-run-before-checkphase/14612/3.
  src = fetchFromGitHub {
    owner = "google";
    repo = pname;
    # Upstream release tags have the form `jax-v<version>`.
    rev = "jax-v${version}";
    sha256 = "05w157h6jv20k8w2gnmlxbycmzf24lr5v392q0c5v0qcql11q7pn";
  };

  # jaxlib is _not_ included in propagatedBuildInputs because there are
  # different versions of jaxlib depending on the desired target hardware. The
  # JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
  # CPU wheel is packaged.
  propagatedBuildInputs = [ absl-py numpy opt-einsum ];

  # jaxlib is still required to run the test suite, so it is provided only as a
  # check-time input.
  checkInputs = [ jaxlib pytestCheckHook ];

  # NOTE: Don't run the tests in the experimental directory as they require flax
  # which creates a circular dependency. See
  # https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # Not a big deal, this is how the JAX docs suggest running the test suite
  # anyhow.
  pytestFlagsArray = [ "-W ignore::DeprecationWarning" "tests/" ];

  meta = with lib; {
    description = "Differentiate, compile, and transform Numpy code";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ samuela ];
  };
}