Does PyTorch3D support autocast?

Open tejank10 opened this issue 4 years ago • 4 comments

Hi, I was wondering whether the current release of PyTorch3D is intended to support the latest amp and autocast features of PyTorch 1.6. I tried rendering a mesh with autocast enabled, but it gave the following error. Is the current release supposed to support it? If not, is this feature planned for a future release?

Thanks.

  File "demo.py", line 158, in <module>
    pyrendering = renderer(mesh)
  File "/home/vikrant/CMR_expts/cmr_amp_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/vikrant/pytorch3d/pytorch3d/renderer/mesh/renderer.py", line 48, in forward
    fragments = self.rasterizer(meshes_world, **kwargs)
  File "/home/vikrant/CMR_expts/cmr_amp_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/vikrant/pytorch3d/pytorch3d/renderer/mesh/rasterizer.py", line 120, in forward
    meshes_screen = self.transform(meshes_world, **kwargs)
  File "/home/vikrant/pytorch3d/pytorch3d/renderer/mesh/rasterizer.py", line 102, in transform
    verts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
  File "/home/vikrant/pytorch3d/pytorch3d/renderer/cameras.py", line 149, in get_world_to_view_transform
    world_to_view_transform = get_world_to_view_transform(R=self.R, T=self.T)
  File "/home/vikrant/pytorch3d/pytorch3d/renderer/cameras.py", line 839, in get_world_to_view_transform
    R = Rotate(R, device=R.device)
  File "/home/vikrant/pytorch3d/pytorch3d/transforms/transform3d.py", line 511, in __init__
    _check_valid_rotation_matrix(R, tol=orthogonal_tol)
  File "/home/vikrant/pytorch3d/pytorch3d/transforms/transform3d.py", line 702, in _check_valid_rotation_matrix
    orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol)
RuntimeError: Half did not match Float

tejank10, Aug 03 '20 12:08

PyTorch3D has not been tested with autocast or designed with autocast in mind. The custom PyTorch3D ops (written in CUDA and C++) could potentially have problems with autocast. ~~But I would expect the pure Python parts of PyTorch3D to just work. If I understand autocast correctly, there is a problem with the allclose function in PyTorch, which is a PyTorch bug, and that is the immediate source of your error.~~

import torch
from torch.cuda.amp import autocast

with autocast():
    a = torch.zeros((1, 4, 4), device="cuda:0")  # factory functions stay float32
    c = torch.bmm(a, a)  # bmm is autocast to half precision
    print(c.dtype)  # => torch.float16
    print(torch.allclose(a, c))  # => RuntimeError: Float did not match Half

EDIT: Reading the autocast documentation more closely, it isn't clear that unlisted ops will work with autocast. This example, and PyTorch3D in general, use ops that autocast is not required to support.

bottler, Aug 03 '20 14:08
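If a comparison like the one above has to run inside an autocast region, one option is to bring both tensors to a common dtype first, since torch.allclose rejects mismatched dtypes. This is a minimal sketch under stated assumptions, not PyTorch3D's fix: allclose_any_dtype is a hypothetical helper, and the example uses CPU autocast with bfloat16 via torch.autocast (PyTorch 1.10 or later) so it runs without a GPU.

```python
import torch


def allclose_any_dtype(x: torch.Tensor, y: torch.Tensor, **kwargs) -> bool:
    """Hypothetical helper: torch.allclose after promoting to a shared dtype.

    torch.allclose requires both inputs to have the same dtype, so promote
    them first (for example, bfloat16 and float32 both become float32).
    """
    common = torch.promote_types(x.dtype, y.dtype)
    return torch.allclose(x.to(common), y.to(common), **kwargs)


a = torch.zeros((1, 4, 4))  # created outside autocast, so float32
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    c = torch.bmm(a, a)  # may come back in bfloat16 under CPU autocast
    same = allclose_any_dtype(a, c)

print(c.dtype, same)
```

Promoting rather than downcasting keeps the comparison at the higher precision; the tolerance still has to suit whichever input was computed in lower precision.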

We will keep this issue open as an enhancement feature we can consider for future.

nikhilaravi, Aug 05 '20 23:08
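Until such support exists, one workaround that follows the PyTorch AMP documentation's advice for float32-sensitive regions is to keep autocast for the rest of the model and disable it locally around the rendering call. Below is a minimal sketch under stated assumptions: run_without_autocast is a hypothetical helper, two nn.Linear layers stand in for a network and a PyTorch3D renderer so it runs without PyTorch3D or a GPU, and it uses torch.autocast (PyTorch 1.10 or later); with PyTorch 1.6 the equivalent context manager is torch.cuda.amp.autocast(enabled=False). The helper only casts plain tensors; objects such as meshes or cameras would need their own float32 conversion.

```python
import torch


def run_without_autocast(fn, *tensors, device_type: str = "cpu"):
    """Hypothetical helper: call fn with autocast disabled.

    Floating-point tensors produced under autocast (float16/bfloat16) are
    cast back to float32, so code that expects float32 sees matching dtypes.
    """
    with torch.autocast(device_type=device_type, enabled=False):
        args = [t.float() if t.is_floating_point() else t for t in tensors]
        return fn(*args)


backbone = torch.nn.Linear(4, 4)  # stand-in for the rest of the network
renderer = torch.nn.Linear(4, 4)  # stand-in for a PyTorch3D renderer

x = torch.randn(2, 4)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    features = backbone(x)  # lower precision where autocast applies
    out = run_without_autocast(renderer, features)  # float32 end to end

print(features.dtype, out.dtype)
```

In the original report it is disabling autocast that matters: the failing bmm multiplies the camera's own float32 rotation matrices, so it stays float32 once autocast is off. The .float() casts only matter when the renderer's inputs were themselves produced inside an autocast region.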

I got "RuntimeError: expected scalar type Double but found Float" when I ran:

images = renderer(mesh, lights=lights, materials=materials, cameras=cameras)

My data is in torch double (float64), and the R and T for the LookAt view transformation are also double. I have no idea what else I could do to prevent a float input. Is this also related to autocast?

RuntimeError                              Traceback (most recent call last)
----> 1 images = renderer(mesh, lights=lights, materials=materials, cameras=cameras)

~/anaconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/pytorch3d/lib/python3.8/site-packages/pytorch3d/renderer/mesh/renderer.py in forward(self, meshes_world, **kwargs)
     46             the range for the corresponding face.
     47         """
---> 48         fragments = self.rasterizer(meshes_world, **kwargs)
     49         images = self.shader(fragments, meshes_world, **kwargs)
     50

~/anaconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/pytorch3d/lib/python3.8/site-packages/pytorch3d/renderer/mesh/rasterizer.py in forward(self, meshes_world, **kwargs)
    124             Fragments: Rasterization outputs as a named tuple.
    125         """
--> 126         meshes_screen = self.transform(meshes_world, **kwargs)
    127         raster_settings = kwargs.get("raster_settings", self.raster_settings)
    128

~/anaconda3/envs/pytorch3d/lib/python3.8/site-packages/pytorch3d/renderer/mesh/rasterizer.py in transform(self, meshes_world, **kwargs)
    106         # TODO: Revisit whether or not to transform z coordinate to [-1, 1] or
    107         # [0, 1] range.
--> 108         verts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
    109             verts_world
    110         )

~/anaconda3/envs/pytorch3d/lib/python3.8/site-packages/pytorch3d/transforms/transform3d.py in transform_points(self, points, eps)
    311
    312         composed_matrix = self.get_matrix()
--> 313         points_out = _broadcast_bmm(points_batch, composed_matrix)
    314         denom = points_out[..., 3:]  # denominator
    315         if eps is not None:

~/anaconda3/envs/pytorch3d/lib/python3.8/site-packages/pytorch3d/transforms/transform3d.py in _broadcast_bmm(a, b)
    679     if len(b) == 1:
    680         b = b.expand(len(a), -1, -1)
--> 681     return a.bmm(b)
    682
    683

RuntimeError: expected scalar type Double but found Float

Yumin-Sun-00, Oct 27 '20 16:10

@Yumin-Sun-00 Is your question about autocast? If so, I am confused: why would you use autocast with double-precision data? If not, please open a new issue and follow the guidelines for new issues, such as giving the complete code.

Note that much of the floating-point data in PyTorch3D is generally expected to be single precision (FloatTensor).

bottler, Oct 27 '20 18:10
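In line with the note about single precision, a straightforward fix for the second error is to convert double-precision inputs to float32 before they reach PyTorch3D. The sketch below reproduces the failing a.bmm(b) call from _broadcast_bmm with plain tensors (the shapes are illustrative assumptions, not taken from the reporter's code) and then applies .float(). In the reporter's setup, the same conversion would apply to the mesh data and to the R and T passed to the cameras.

```python
import torch

points = torch.rand(1, 5, 4, dtype=torch.float64)  # user data in double
matrix = torch.eye(4).unsqueeze(0)  # float32, PyTorch's default dtype

# bmm does not promote dtypes, so mixing float64 and float32 raises,
# as in _broadcast_bmm in the traceback above.
try:
    points.bmm(matrix)
    mixed_failed = False
except RuntimeError as err:
    mixed_failed = True
    print("mixed dtypes:", err)

# Converting the double data to float32 up front avoids the mismatch.
out = points.float().bmm(matrix)
print(out.dtype)  # torch.float32
```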