nixpkgs/pkgs/development/python-modules/skorch/default.nix
2025-06-09 01:06:38 +02:00

105 lines
2.5 KiB
Nix

# Python package definition for skorch, a scikit-learn compatible neural net
# library built on PyTorch.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,
  pythonOlder,
  numpy,
  scikit-learn,
  scipy,
  setuptools,
  tabulate,
  torch,
  tqdm,
  flaky,
  llvmPackages,
  pandas,
  pytest-cov-stub,
  pytestCheckHook,
  safetensors,
  transformers,
  pythonAtLeast,
}:

buildPythonPackage rec {
  pname = "skorch";
  version = "1.1.0";
  pyproject = true;

  # Sources are taken from the upstream git tag matching `version`.
  src = fetchFromGitHub {
    owner = "skorch-dev";
    repo = "skorch";
    tag = "v${version}";
    # SRI-formatted hashes belong in `hash`; `sha256` is the legacy attribute.
    hash = "sha256-f0g/kn3HhvYfGDgLpA7gAnYocJrYqHUq680KrGuoPCQ=";
  };

  # Python 3.13 is excluded because of:
  #   AttributeError: 'NoneType' object has no attribute 'span'
  # https://github.com/skorch-dev/skorch/issues/1080
  disabled = pythonOlder "3.9" || pythonAtLeast "3.13";

  build-system = [ setuptools ];

  dependencies = [
    numpy
    pandas
    scikit-learn
    scipy
    tabulate
    torch # implicit dependency
    tqdm
  ];

  nativeCheckInputs = [
    flaky
    pytest-cov-stub
    pytestCheckHook
    safetensors
    transformers
  ];

  # The OpenMP runtime is only added when the stdenv compiler is clang.
  checkInputs = lib.optionals stdenv.cc.isClang [ llvmPackages.openmp ];

  disabledTests = [
    # on CPU, these expect artifacts from a previous GPU run
    "test_load_cuda_params_to_cpu"

    # failing tests
    "test_pickle_load"

    # there is a problem with the compiler selection
    "test_fit_and_predict_with_compile"

    # "Weights only load failed"
    "test_can_be_copied"
    "test_pickle"
    "test_pickle_save_load"
    "test_train_net_after_copy"
    "test_weights_restore"

    # reported as flaky
    "test_fit_lbfgs_optimizer"
  ];

  disabledTestPaths = [
    # tries to download missing HuggingFace data
    "skorch/tests/test_dataset.py"
    "skorch/tests/test_hf.py"
    "skorch/tests/llm/test_llm_classifier.py"

    # These tests use the same hardcoded port, so they fail when run in
    # parallel on every platform with sandboxing enabled:
    #   RuntimeError: The server socket has failed to listen on any local
    #   network address
    "skorch/tests/test_history.py"
  ];

  pythonImportsCheck = [ "skorch" ];

  meta = {
    description = "Scikit-learn compatible neural net library using Pytorch";
    homepage = "https://skorch.readthedocs.io";
    changelog = "https://github.com/skorch-dev/skorch/blob/master/CHANGES.md";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ bcdarwin ];
    badPlatforms = [
      # Most tests fail with:
      # Fatal Python error: Segmentation fault
      lib.systems.inspect.patterns.isDarwin
    ];
  };
}