
fp16 division is inconsistent with torch

Open: ofhwei opened this issue 3 years ago (1 comment)

Summary

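Dividing two fp16 (float16) tensors returns a result whose dtype differs from torch: OneFlow returns float32, while torch returns float16.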

Code to reproduce bug

>>> import oneflow as flow
>>> a = flow.randn(3, 3, dtype=flow.float16).cuda()
>>> b = flow.randn(3, 3, dtype=flow.float16).cuda()
>>> a/b
tensor([[-2.1495e-03,  1.5983e+00, -5.2973e-01],
        [-1.7968e-01, -4.0361e+00,  5.4459e-01],
        [ 5.0794e+00,  5.5164e-01,  1.2128e+00]], device='cuda:0', dtype=oneflow.float32)
>>> import torch
>>> a = torch.randn(3, 3, dtype=torch.float16).cuda()
>>> b = torch.randn(3, 3, dtype=torch.float16).cuda()
>>> a/b
tensor([[  2.7461,   2.4453,   0.3870],
        [  1.8223,  -0.7861,  -0.8438],
        [  0.3005,  -0.0428, -19.0156]], device='cuda:0', dtype=torch.float16)
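The dtype torch reports follows its type-promotion rules: dividing two tensors of the same floating dtype keeps that dtype, and only integer / integer true division falls back to the default float dtype. Below is a minimal sketch of that rule, assuming two non-scalar tensors and modelling dtypes as plain strings. `true_divide_dtype` and `promote_float` are illustrative names, not part of OneFlow or PyTorch.

```python
# Sketch (assumption): PyTorch-style result dtype for tensor / tensor true
# division, for two non-scalar tensors. Dtypes are modelled as strings.

FLOAT_RANK = {"float16": 0, "bfloat16": 0, "float32": 1, "float64": 2}
INT_TYPES = {"bool", "uint8", "int8", "int16", "int32", "int64"}
DEFAULT_FLOAT = "float32"  # assumes the default float dtype was not changed


def promote_float(a: str, b: str) -> str:
    """Return the smallest floating dtype that can represent both inputs."""
    if a == b:
        return a  # same dtype in, same dtype out (e.g. float16 / float16)
    # float16 and bfloat16 have incompatible ranges, so they meet at float32.
    if {a, b} == {"float16", "bfloat16"}:
        return "float32"
    return a if FLOAT_RANK[a] > FLOAT_RANK[b] else b


def true_divide_dtype(a: str, b: str) -> str:
    """Result dtype of a / b under the sketched promotion rule."""
    a_int, b_int = a in INT_TYPES, b in INT_TYPES
    if a_int and b_int:
        # Integer / integer true division produces the default float dtype.
        return DEFAULT_FLOAT
    if a_int:
        return b  # the floating operand's dtype wins over an integer one
    if b_int:
        return a
    return promote_float(a, b)
```

Under this rule, `true_divide_dtype("float16", "float16")` gives `"float16"`, matching the torch output above, whereas OneFlow 0.8.1 returned float32.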

System Information

  • What is your OneFlow installation (pip, source, dockerhub): source
  • OS: linux
  • OneFlow version (run python3 -m oneflow --doctor): version: 0.8.1+cu117.git.f6852c1135 git_commit: f6852c1135 cmake_build_type: RelWithDebInfo rdma: False mlir: False

ofhwei, Nov 16 '22 08:11

I think autotest's random_tensor could randomize the dtype as well.

daquexian, Nov 16 '22 08:11
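The idea of randomizing the dtype can be sketched framework-free: draw input dtypes at random and compare the result dtype of a candidate against a reference. Everything here is hypothetical. `random_dtype`, `reference_rule`, `buggy_rule` and `find_dtype_mismatches` are not part of OneFlow's autotest, and the two rules are stand-ins for "torch" and "OneFlow as reported", covering only float16/float32/float64.

```python
import random
from typing import Callable, List, Sequence, Tuple

DTypeRule = Callable[[str, str], str]

FLOAT_DTYPES = ("float16", "float32", "float64")
RANK = {"float16": 0, "float32": 1, "float64": 2}


def random_dtype(rng: random.Random, choices: Sequence[str] = FLOAT_DTYPES) -> str:
    # Hypothetical helper: choose a dtype at random instead of a fixed one.
    return rng.choice(choices)


def reference_rule(a: str, b: str) -> str:
    # Stand-in for torch: the wider of the two floating dtypes.
    return a if RANK[a] >= RANK[b] else b


def buggy_rule(a: str, b: str) -> str:
    # Stand-in for the reported behaviour: half results come back as float32.
    out = reference_rule(a, b)
    return "float32" if out == "float16" else out


def find_dtype_mismatches(candidate: DTypeRule, reference: DTypeRule,
                          trials: int = 200, seed: int = 0
                          ) -> List[Tuple[str, str, str, str]]:
    """Return sorted (a, b, got, want) tuples where the result dtypes differ."""
    rng = random.Random(seed)  # seeded so failures are reproducible
    mismatches = set()
    for _ in range(trials):
        a, b = random_dtype(rng), random_dtype(rng)
        got, want = candidate(a, b), reference(a, b)
        if got != want:
            mismatches.add((a, b, got, want))
    return sorted(mismatches)
```

With a fixed input dtype of float32 the two rules never disagree. Once the dtype is randomized, `find_dtype_mismatches(buggy_rule, reference_rule)` reports the float16 / float16 case.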