[ROCm] float8 does not work
Hi @hongxiayang @hliuca,
It seems that float8 training via torchao.float8 is not supported on ROCm at the moment. Is there a different library or code path I should be using for float8 training, or what are the timelines for ROCm support of torchao.float8?
Attempting Install From Nightly
When using the ROCm nightly torchao wheel, the torchao.float8 module is not present:
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/rocm6.2
python -c "import torchao; print(dir(torchao))"
['__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'apply_dynamic_quant', 'apply_weight_only_int8_quant', 'dtypes', 'kernel', 'quantization']
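(Aside, not from the original report: a quick way to confirm the submodule is genuinely absent from the wheel, rather than merely failing to import, is to query importlib:
python -c "import importlib.util; print(importlib.util.find_spec('torchao.float8'))"
which prints None when torchao.float8 is not shipped in the installed package.)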
Attempting Install From Source
When installing from source, I run into a Triton datatype issue. If I disable torch.compile, then I run into the eager-mode fp8 dtype being the Nvidia fp8 format rather than the AMD format.
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2
pip install git+https://github.com/pytorch/ao.git
Eager Mode Error
tensor_out = addmm_float8_unwrapped(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torchao/float8/float8_python_api.py", line 55, in addmm_float8_unwrapped
output = torch._scaled_mm(
RuntimeError: false INTERNAL ASSERT FAILED at "../aten/src/ATen/hip/HIPDataType.h":102, please report a bug to PyTorch. Cannot convert ScalarType Float8_e4m3fn to hipDataType.
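The assert mentions Float8_e4m3fn, the Nvidia-style e4m3 dtype; AMD GPUs use the fnuz fp8 variants, which recent PyTorch builds expose as separate dtypes. A quick probe (my own sketch for illustration, not part of the original logs):

import torch
# check which fp8 dtypes this build exposes; the fnuz variants are the AMD-oriented ones
for name in ("float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz"):
    print(name, hasattr(torch, name))

So the dtype itself is available; the failure is in mapping torch.float8_e4m3fn onto a hipDataType inside torch._scaled_mm.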
Compile Mode Error
tmp15 = 448.0
tmp16 = triton_helpers.minimum(tmp14, tmp15)
tmp17 = tmp16.to(tl.float8e4nv)
^
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
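This looks like the same root cause surfacing through Inductor: the generated kernel clamps to 448.0 and casts to tl.float8e4nv, both of which correspond to the Nvidia e4m3 format rather than the AMD fnuz variant. For reference, a sketch added here for illustration (values from torch.finfo):

import torch
# 448.0 matches the max of the Nvidia-style e4m3 dtype that the generated kernel targets;
# the AMD fnuz variant has a smaller representable range
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0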
Repro Script (From the torchao.float8 README Example)
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()