
[ROCm] float8 does not work

Open OrenLeung opened this issue 4 months ago • 1 comment

Hi @hongxiayang @hliuca ,

It seems like float8 training via torchao.float8 is not supported on ROCm at the moment. Is there a different library or code path I should be using for float8 training, or what are the timelines for ROCm support in torchao.float8?

Attempting Install From Nightly

Using the ROCm nightly torchao wheel, the torchao.float8 module is not present:

pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/rocm6.2
python -c "import torchao; print(dir(torchao))"
['__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'apply_dynamic_quant', 'apply_weight_only_int8_quant', 'dtypes', 'kernel', 'quantization']
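As a quicker check than eyeballing dir(), one can test whether an installed wheel actually ships a given submodule (a small stdlib-only sketch; the module name is just the one in question here):

```python
import importlib.util

def has_module(name: str) -> bool:
    """True if `name` resolves to an importable module or subpackage."""
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        # The parent package itself is not installed.
        return False

print(has_module("torchao.float8"))
```

With the nightly ROCm wheel above, this prints False.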

Attempting Install From Source

When installing from source, I run into a Triton datatype issue under torch.compile. If I disable torch.compile, I instead hit an eager-mode error because the fp8 dtype being used is the NVIDIA format (Float8_e4m3fn) rather than the AMD format.

pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2
pip install git+https://github.com/pytorch/ao.git

Eager Mode Error

   tensor_out = addmm_float8_unwrapped(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torchao/float8/float8_python_api.py", line 55, in addmm_float8_unwrapped
    output = torch._scaled_mm(
RuntimeError: false INTERNAL ASSERT FAILED at "../aten/src/ATen/hip/HIPDataType.h":102, please report a bug to PyTorch. Cannot convert ScalarType Float8_e4m3fn to hipDataType.

Compile Mode Error

    tmp15 = 448.0
    tmp16 = triton_helpers.minimum(tmp14, tmp15)
    tmp17 = tmp16.to(tl.float8e4nv)
            ^

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
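For context, both errors point at the same mismatch: the ScalarType that fails to convert (Float8_e4m3fn) and the 448.0 clamp in the generated Triton kernel are tied to the OCP e4m3 format used on NVIDIA, while AMD MI300 kernels expect the fnuz variant (torch.float8_e4m3fnuz in recent PyTorch), which has a smaller maximum. A back-of-the-envelope sketch of the difference (pure Python; format parameters are my reading of the two specs):

```python
# OCP e4m3fn (NVIDIA): exponent bias 7, no infinities; at the top exponent
# code only the all-ones mantissa is NaN, so 1.110_2 * 2^8 is still finite.
ocp_e4m3_max = 2.0 ** 8 * (1 + 6 / 8)    # 448.0 -- the clamp in the Triton code

# e4m3fnuz (AMD): exponent bias 8, no infinities, NaN is the negative-zero
# encoding, so the full mantissa is usable at the top exponent -- but the
# larger bias shifts the largest exponent down to 2^7.
fnuz_e4m3_max = 2.0 ** 7 * (1 + 7 / 8)   # 240.0

print(ocp_e4m3_max, fnuz_e4m3_max)
```

So even where the eager kernel dispatch worked, scaling constants baked in for e4m3fn would over-range the AMD format.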

Repro Script From the torchao.float8 README Example

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)
# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

OrenLeung — Oct 12 '24 23:10