Enable CUDA support for PyTorch backend
Description
While TensorFlow and JAX will detect out of the box if there is an available GPU for them to run on and utilize it (if the system environment is properly configured for CUDA and cuDNN), PyTorch sets the default device to the CPU and needs to be told explicitly to use CUDA-enabled devices (GPUs). This is currently (as of v0.5.4) not supported in pyhf at all.
While this has to be done explicitly, it can at least be done easily with the `device` keyword argument of `torch.as_tensor`:
> `device` (`torch.device`, optional) – the desired device of returned tensor. Default: if `None`, uses the current device for the default tensor type (see `torch.set_default_tensor_type()`). `device` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
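For example, a minimal sketch of that device placement (the array values are just for demonstration):

```python
import torch

# Select CUDA when the environment provides it, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.as_tensor places the returned tensor directly on the requested device
values = torch.as_tensor([1.0, 2.0, 3.0], dtype=torch.float32, device=device)
print(values.device)  # cuda:0 on a GPU-enabled machine, otherwise cpu
```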
So hopefully this will only require modifying the PyTorch backend's `astensor` method

https://github.com/scikit-hep/pyhf/blob/95ad1516af0c22d851557cfff76628a87829b22a/src/pyhf/tensor/pytorch_backend.py#L180

in addition to adding some new keyword arguments to the backend `__init__` methods and `pyhf.set_backend`. So maybe something like
```python
import torch


class pytorch_backend:
    """PyTorch backend for pyhf"""

    __slots__ = [
        "name",
        "precision",
        "dtypemap",
        "default_do_grad",
        "use_cuda",
        "device",
    ]

    def __init__(self, **kwargs):
        self.name = 'pytorch'
        self.precision = kwargs.get('precision', '32b')
        self.dtypemap = {
            'float': torch.float64 if self.precision == '64b' else torch.float32,
            'int': torch.int64 if self.precision == '64b' else torch.int32,
            'bool': torch.bool,
        }
        self.default_do_grad = True
        # Use the GPU by default when one is available, unless told otherwise
        self.use_cuda = kwargs.get("use_gpu", torch.cuda.is_available())
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        # ...

    def astensor(self, tensor_in, dtype='float'):
        # ...
        dtype = self.dtypemap[dtype]  # map the string name to a torch dtype
        return torch.as_tensor(tensor_in, dtype=dtype, device=self.device)
```
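With a change along those lines, enabling the GPU could look something like the following usage sketch (the `use_gpu` keyword is from the proposal above, not part of any released API):

```python
import pyhf

# Hypothetical usage: construct the backend with the proposed keyword argument
# and hand the instance to set_backend, which already accepts backend objects
pyhf.set_backend(pyhf.tensor.pytorch_backend(use_gpu=True, precision="64b"))
```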
However, from some quick tests it looks like a bit more work will be needed.
Example failure from tensors being allocated on both the CPU and the GPU:
```console
$ time pyhf cls --backend pytorch HVTWZ_3500.json
Traceback (most recent call last):
  File "/home/feickert/.pyenv/versions/pyhf-dev/bin/pyhf", line 33, in <module>
    sys.exit(load_entry_point('pyhf', 'console_scripts', 'pyhf')())
  File "/home/feickert/.pyenv/versions/3.8.5/envs/pyhf-dev/lib/python3.8/site-packages/click/core.py", line 829, in __call__
    return self.main(*args, **kwargs)
  File "/home/feickert/.pyenv/versions/3.8.5/envs/pyhf-dev/lib/python3.8/site-packages/click/core.py", line 782, in main
    rv = self.invoke(ctx)
  File "/home/feickert/.pyenv/versions/3.8.5/envs/pyhf-dev/lib/python3.8/site-packages/click/core.py", line 1259, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/home/feickert/.pyenv/versions/3.8.5/envs/pyhf-dev/lib/python3.8/site-packages/click/core.py", line 1066, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/feickert/.pyenv/versions/3.8.5/envs/pyhf-dev/lib/python3.8/site-packages/click/core.py", line 610, in invoke
    return callback(*args, **kwargs)
  File "/home/feickert/workarea/pyhf/src/pyhf/cli/infer.py", line 207, in cls
    set_backend("pytorch", precision="64b")
  File "/home/feickert/workarea/pyhf/src/pyhf/events.py", line 78, in register_wrapper
    result = func(*args, **kwargs)
  File "/home/feickert/workarea/pyhf/src/pyhf/__init__.py", line 161, in set_backend
    events.trigger("tensorlib_changed")()
  File "/home/feickert/workarea/pyhf/src/pyhf/events.py", line 21, in __call__
    func()(*args, **kwargs)
  File "/home/feickert/workarea/pyhf/src/pyhf/interpolators/code4.py", line 137, in _precompute
    self.bases_up = tensorlib.einsum(
  File "/home/feickert/workarea/pyhf/src/pyhf/tensor/pytorch_backend.py", line 328, in einsum
    return torch.einsum(subscripts, operands)
  File "/home/feickert/.pyenv/versions/3.8.5/envs/pyhf-dev/lib/python3.8/site-packages/torch/functional.py", line 342, in einsum
    return einsum(equation, *_operands)
  File "/home/feickert/.pyenv/versions/3.8.5/envs/pyhf-dev/lib/python3.8/site-packages/torch/functional.py", line 344, in einsum
    return _VF.einsum(equation, operands)  # type: ignore
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
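For reference, that failure mode is easy to reproduce directly in PyTorch (a minimal sketch, assuming a CUDA-capable machine; the tensor shapes are arbitrary stand-ins):

```python
import torch

if torch.cuda.is_available():
    on_gpu = torch.ones((3, 4), device="cuda")  # allocated on the GPU
    on_cpu = torch.ones((4, 5))  # default device is the CPU

    try:
        torch.einsum("ij,jk->ik", on_gpu, on_cpu)
    except RuntimeError as err:
        print(err)  # Expected all tensors to be on the same device, ...

    # Moving every operand to a common device resolves the error
    result = torch.einsum("ij,jk->ik", on_gpu, on_cpu.to(on_gpu.device))
```

So it looks like any tensors built outside of `astensor` (for example in the interpolator precomputations) will also need to end up on the backend's device.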
I've started some outline work on the branch `feat/add-gpu-to-torch`, and this work will naturally need to address some parts of Issue #896.
As this will necessarily be an API-breaking change, it should go into v0.7.0, and might be a good motivation to get it out soon after v0.6.0 is released.
There are some additionally helpful examples in the PyTorch documentation on CUDA semantics that demonstrate the need to specify devices.
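For instance, a small illustration of those semantics (a sketch, assuming a CUDA-capable machine):

```python
import torch

if torch.cuda.is_available():
    x = torch.tensor([1.0, 2.0])  # allocated on the CPU by default
    y = torch.tensor([1.0, 2.0], device="cuda")  # explicitly on the current CUDA device

    # A torch.cuda.device context changes which CUDA device is "current",
    # but CPU tensors still need an explicit move
    with torch.cuda.device(0):
        z = torch.tensor([1.0, 2.0]).cuda()  # .cuda() moves the tensor to cuda:0

    print(x.device, y.device, z.device)  # cpu cuda:0 cuda:0
```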
Note to self: This also explains why the preliminary CPU vs. GPU benchmarks for the HVTWZ_3500.json workspace showed no difference between the two for PyTorch: GPU mode was never being enabled, so it was running on the CPU both times.
Also reminds me a little of #1145.

> Also reminds me a little of #1145.
Yeah, resolving this Issue could probably address that too.
:wave: @nhartman94 Moving from Issue #896, this seems to currently be on the roadmap for v0.7.0, but if this is impacting you strongly we can look at moving it to v0.6.2.
Sounds great, thanks :) No, this is not impacting me strongly at all, I just wanted to understand how one backend worked w/ the gpu, I don't have a strong preference between pytorch and jax.
> Sounds great, thanks :) No, this is not impacting me strongly at all, I just wanted to understand how one backend worked w/ the gpu, I don't have a strong preference between pytorch and jax.
Cool. Let us know if this changes. Thank you also very much for asking! Questions help us revisit old issues that need action and also help us understand where there are potential pain points, so we really really appreciate them. :)
@condrine if you wanted to look at this with me later in December/January I'd be interested in following up with you. :+1: Thanks for noticing this issue and bringing it back up at the pyhf workshop. :)
@matthewfeickert great! Let me know whenever you are free to look into this. I will float a few code snippets for this here until then. However, I might need some examples to test this on, so if you have any that I can use, or any instructions on how to set one up, please let me know.