diff --git a/pkgs/development/python-modules/dm-haiku/default.nix b/pkgs/development/python-modules/dm-haiku/default.nix index 2c4c24b93670..87de983330c2 100644 --- a/pkgs/development/python-modules/dm-haiku/default.nix +++ b/pkgs/development/python-modules/dm-haiku/default.nix @@ -1,66 +1,49 @@ { buildPythonPackage -, chex -, cloudpickle -, dill -, dm-tree , fetchFromGitHub -, jaxlib -, jmp +, callPackage , lib -, pytest-xdist -, pytestCheckHook +, jmp , tabulate -, tensorflow +, jaxlib }: buildPythonPackage rec { pname = "dm-haiku"; - version = "0.0.6"; + version = "0.0.7"; src = fetchFromGitHub { owner = "deepmind"; repo = pname; rev = "v${version}"; - hash = "sha256-qvKMeGPiWXvvyV+GZdTWdsC6Wp08AmP8nDtWk7sZtqM="; + hash = "sha256-Qa3g3vOPZJt/wBjjuZHAcFUz/gwN/yvirV/8V9CnIko="; }; - propagatedBuildInputs = [ - jmp - tabulate + outputs = [ + "out" + "testsout" ]; - checkInputs = [ - chex - cloudpickle - dill - dm-tree + propagatedBuildInputs = [ jaxlib - pytest-xdist - pytestCheckHook - tensorflow + jmp + tabulate ]; pythonImportsCheck = [ "haiku" ]; - disabledTestPaths = [ - # These tests require `bsuite` which isn't packaged in `nixpkgs`. - "examples/impala_lite_test.py" - "examples/impala/actor_test.py" - "examples/impala/learner_test.py" - # This test breaks on multiple cases with TF-related errors, - # likely that's the reason the upstream uses TF-nightly for tests? - # `nixpkgs` doesn't have the corresponding TF version packaged. - "haiku/_src/integration/jax2tf_test.py" - # `TypeError: lax.conv_general_dilated requires arguments to have the same dtypes, got float32, float16`. - "haiku/_src/integration/numpy_inputs_test.py" - ]; + postInstall = '' + mkdir $testsout + cp -R examples $testsout/examples + ''; - disabledTests = [ - # See https://github.com/deepmind/dm-haiku/issues/366. - "test_jit_Recurrent" - ]; + # check in passthru.tests.pytest to escape infinite recursion with bsuite + doCheck = false; + + passthru.tests = { + pytest = callPackage ./tests.nix { }; + }; meta = with lib; { description = "Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet."; diff --git a/pkgs/development/python-modules/dm-haiku/tests.nix b/pkgs/development/python-modules/dm-haiku/tests.nix new file mode 100644 index 000000000000..93a4f3cd4795 --- /dev/null +++ b/pkgs/development/python-modules/dm-haiku/tests.nix @@ -0,0 +1,68 @@ +{ stdenv +, buildPythonPackage +, dm-haiku +, chex +, cloudpickle +, dill +, dm-tree +, jaxlib +, pytest-xdist +, pytestCheckHook +, tensorflow +, bsuite +, frozendict +, dm-env +, scikitimage +, rlax +, distrax +, tensorflow-probability +, optax }: + +buildPythonPackage rec { + pname = "dm-haiku-tests"; + inherit (dm-haiku) version; + + src = dm-haiku.testsout; + + dontBuild = true; + dontInstall = true; + + checkInputs = [ + bsuite + chex + cloudpickle + dill + distrax + dm-env + dm-haiku + dm-tree + frozendict + jaxlib + pytest-xdist + pytestCheckHook + optax + rlax + scikitimage + tensorflow + tensorflow-probability + ]; + + disabledTests = [ + # See https://github.com/deepmind/dm-haiku/issues/366. + "test_jit_Recurrent" + # Assertion errors + "test_connect_conv_padding_function_same0" + "test_connect_conv_padding_function_valid0" + "test_connect_conv_padding_function_same1" + "test_connect_conv_padding_function_same2" + "test_connect_conv_padding_function_valid1" + "test_connect_conv_padding_function_valid2" + "test_invalid_axis_ListString" + "test_invalid_axis_String" + "test_simple_case" + "test_simple_case_with_scale" + "test_slice_axis" + "test_zero_inputs" + ]; + +}