mach-nix

Cannot use CUDA if I pin a PyTorch version that is not in nixpkgs

Open jian-lin opened this issue 3 years ago • 4 comments

If I do not pin the PyTorch version, so that the version from nixpkgs is used, CUDA is available. The flake is as follows:

{
  inputs = {
    mach-nix = {
      url = "github:DavHau/mach-nix";
      inputs.nixpkgs.follows = "nixpkgs";
      inputs.pypi-deps-db.follows = "pypi-deps-db";
    };
    nixpkgs.url = "github:NixOS/nixpkgs?rev=4f51fae5bbbb5b232ca340f5836875cc2c0f10cc";
    pypi-deps-db = {
      url = "github:DavHau/pypi-deps-db";
      flake = false;
    };
  };

  outputs = inputs: {
    devShell.x86_64-linux =
      let
        pkgs = inputs.nixpkgs.legacyPackages.x86_64-linux;
        myPython = inputs.mach-nix.lib.x86_64-linux.mkPython {
          requirements = ''
            torch # <-- defaults to the version in nixpkgs, which is 1.10
          '';
          python = "python39";
          providers = {
            torch = "nixpkgs";
          };
          overridesPost = [
            (curr: prev: {
              torch = prev.torch.override {
                cudaSupport = true;
              };
            })
          ];
        };
      in
      pkgs.mkShell {
        packages = [
          myPython
        ];
      };
  };
}

However, if I pin the PyTorch version, CUDA is not available. The non-working flake is as follows:

{
  inputs = {
    mach-nix = {
      url = "github:DavHau/mach-nix";
      inputs.nixpkgs.follows = "nixpkgs";
      inputs.pypi-deps-db.follows = "pypi-deps-db";
    };
    nixpkgs.url = "github:NixOS/nixpkgs?rev=4f51fae5bbbb5b232ca340f5836875cc2c0f10cc";
    pypi-deps-db = {
      url = "github:DavHau/pypi-deps-db";
      flake = false;
    };
  };

  outputs = inputs: {
    devShell.x86_64-linux =
      let
        pkgs = inputs.nixpkgs.legacyPackages.x86_64-linux;
        myPython = inputs.mach-nix.lib.x86_64-linux.mkPython {
          requirements = ''
            torch==1.9.0 # <-- pin 1.9.0, which is not in nixpkgs
          '';
          python = "python39";
          providers = {
            torch = "nixpkgs";
          };
          overridesPost = [
            (curr: prev: {
              torch = prev.torch.override {
                cudaSupport = true;
              };
            })
          ];
        };
      in
      pkgs.mkShell {
        packages = [
          myPython
        ];
      };
  };
}
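To compare the two shells, a small check script run with each shell's python can help. This is a minimal sketch, not part of the original report; it relies only on the standard PyTorch attributes torch.cuda.is_available(), torch.cuda.device_count() and torch.version.cuda, and the function name torch_cuda_status is hypothetical.

```python
import importlib.util


def torch_cuda_status() -> str:
    """Describe whether the torch in this interpreter can see a CUDA device."""
    # Report cleanly instead of raising if torch is missing from this environment.
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed in this interpreter"
    import torch

    if not torch.cuda.is_available():
        # torch.version.cuda is None for CPU-only builds.
        return f"torch {torch.__version__}: CUDA not available (built with CUDA: {torch.version.cuda})"
    return (
        f"torch {torch.__version__}: CUDA {torch.version.cuda} available, "
        f"{torch.cuda.device_count()} device(s)"
    )


if __name__ == "__main__":
    print(torch_cuda_status())
```

Running it in both dev shells shows whether the difference is "torch built without CUDA" or "torch built with CUDA but no usable device".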

jian-lin, Apr 21 '22 12:04

@jian-lin, have you tried enabling CUDA globally when importing nixpkgs? I'd guess that might help. Currently you only enable CUDA for PyTorch, not for its transitive dependencies.

Try replacing your pkgs = ... binding with

pkgs = import inputs.nixpkgs {
  system = "x86_64-linux";
  config = {
    allowUnfree = true;  # the CUDA toolkit is unfree
    cudaSupport = true;  # build CUDA variants of packages that honor this flag
  };
};

instead of using the overridesPost attribute in the mkPython call.

wucke13, Aug 10 '22 15:08

Thanks @wucke13. I tried your suggestion, but it still doesn't work.

jian-lin, Aug 18 '22 16:08

If you use PyTorch from pip, it won't have the CUDA wrapping that the nixpkgs build gets.

You can try the good old

pkgs.mkShell {
  packages = [
    myPython
  ];
  shellHook = ''
    export CUDA_PATH=${pkgs.cudaPackages_11_6.cudatoolkit}
    export CUDA_HOME=${pkgs.cudaPackages_11_6.cudatoolkit}
    export LD_LIBRARY_PATH="${pkgs.cudaPackages_11_6.cudatoolkit}/lib"
  '';
}

hack, provided the PyTorch build it ends up downloading was compiled with CUDA support.
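The export in that shellHook replaces whatever LD_LIBRARY_PATH was already set. If you would rather launch the interpreter from a script, a sketch like the following prepends directories instead of overwriting them. The helper name with_library_dirs and the /path/to/cudatoolkit/lib directory are hypothetical placeholders, not anything mach-nix or nixpkgs provides.

```python
import os
import subprocess
import sys
from typing import Dict, List, Mapping


def with_library_dirs(env: Mapping[str, str], lib_dirs: List[str]) -> Dict[str, str]:
    """Return a copy of env with lib_dirs prepended to LD_LIBRARY_PATH.

    Unlike `export LD_LIBRARY_PATH=...`, entries that were already set are kept.
    """
    new_env = dict(env)  # never mutate the caller's mapping
    existing = [d for d in new_env.get("LD_LIBRARY_PATH", "").split(os.pathsep) if d]
    new_env["LD_LIBRARY_PATH"] = os.pathsep.join(list(lib_dirs) + existing)
    return new_env


if __name__ == "__main__":
    # Hypothetical directory: substitute the lib/ directory of your cudatoolkit store path.
    env = with_library_dirs(os.environ, ["/path/to/cudatoolkit/lib"])
    # Start a child interpreter with the extended search path and show what it sees.
    child = subprocess.run(
        [sys.executable, "-c", "import os; print(os.environ['LD_LIBRARY_PATH'])"],
        env=env, capture_output=True, text=True, check=True,
    )
    print(child.stdout.strip())
```

To target the built environment, pass ./result/bin/python instead of sys.executable.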

ryanswrt, Aug 19 '22 11:08

@ryanswrt thanks for the tip.

Actually, I did manage to get PyTorch with CUDA working using the pip provider on NixOS by running LD_LIBRARY_PATH=/run/opengl-driver/lib ./result/bin/python. This example program works. Maybe that's because the example doesn't need the cudatoolkit libraries?
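One way to narrow down which libraries a pip-provided torch can actually resolve is to ask the dynamic loader directly. The sketch below is a diagnostic under stated assumptions, not a mach-nix feature: it checks whether each library name appears in LD_LIBRARY_PATH and whether ctypes can load it, defaulting to libcuda.so.1, the NVIDIA driver library that /run/opengl-driver/lib typically provides on NixOS. The function names are hypothetical.

```python
import ctypes
import os
import sys
from typing import Optional


def find_in_search_path(name: str, search_path: str) -> Optional[str]:
    """Return the first file called `name` in a colon-separated search path, or None."""
    for directory in search_path.split(os.pathsep):
        candidate = os.path.join(directory, name)
        # Skip empty entries; isfile follows symlinks, as NixOS driver paths often are.
        if directory and os.path.isfile(candidate):
            return candidate
    return None


def can_dlopen(name: str) -> bool:
    """True if the dynamic loader can load `name` in the current process."""
    try:
        ctypes.CDLL(name)
    except OSError:
        return False
    return True


if __name__ == "__main__":
    # Library names may be passed as arguments; default to the driver library.
    names = sys.argv[1:] or ["libcuda.so.1"]
    search_path = os.environ.get("LD_LIBRARY_PATH", "")
    for lib in names:
        print(f"{lib}: found at {find_in_search_path(lib, search_path)}, "
              f"loadable: {can_dlopen(lib)}")
```

For example, LD_LIBRARY_PATH=/run/opengl-driver/lib ./result/bin/python check_libs.py libcuda.so.1, where check_libs.py is a placeholder name. The loader reads LD_LIBRARY_PATH when the process starts, so it must be set on the command line rather than from inside the script.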

Another question: is there a way to find out which CUDA version the pip-provided PyTorch needs before myPython is built? For now, I can only get it by running ./result/bin/python -c "import torch; print(torch.__version__)" after myPython is built; 1.9.0+cu102 means CUDA 10.2 is needed. I guess I need to dig into mach-nix now.
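Until there is a way to query this up front, decoding the version string after the build can at least be automated. This sketch assumes PyTorch's +cuXYZ local-version convention, which the 1.9.0+cu102 example follows; the function name is hypothetical, and it does not tell you the CUDA version before myPython is built.

```python
import re
from typing import Optional


def cuda_version_from_torch(version: str) -> Optional[str]:
    """Map a torch version such as '1.9.0+cu102' to the CUDA version it targets ('10.2').

    Returns None when there is no '+cuXYZ' local tag, e.g. '+cpu' or untagged builds.
    """
    match = re.search(r"\+cu(\d{2,})$", version)
    if match is None:
        return None
    digits = match.group(1)
    # The final digit is the CUDA minor version: cu102 -> 10.2, cu113 -> 11.3, cu92 -> 9.2.
    return f"{digits[:-1]}.{digits[-1]}"


if __name__ == "__main__":
    print(cuda_version_from_torch("1.9.0+cu102"))
```

Inside the built environment, torch.version.cuda reports the same information directly, e.g. ./result/bin/python -c "import torch; print(torch.version.cuda)".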

jian-lin, Aug 19 '22 11:08