torch-mlir

FX importer does not support bfloat16 (cannot create bfloat16 literal tensors)

Open · kumardeepakamd opened this issue 1 year ago · 1 comment

Error seen:

Traceback (most recent call last):
  File "/proj/gdba/kumar/nod/SHARK-TestSuite/e2eshark/t-r-bf16-direct-fx-importer/pytorch/combinations/mlp/runmodel.py", line 131, in <module>
    torch_mlir_model = export_and_import(model, test_input)
  File "/proj/gdba/kumar/nod/SHARK-TestSuite/e2eshark/t-r-bf16-direct-fx-importer/pytorch/combinations/mlp/runmodel.py", line 59, in export_and_import
    fx_importer.import_frozen_exported_program(prog)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 351, in import_frozen_exported_program
    self.import_stateless_graph(g)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 377, in import_stateless_graph
    node_importer.import_nodes(g.nodes)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 620, in import_nodes
    self._import_torch_op_overload(loc, node, target)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 798, in _import_torch_op_overload
    self._import_argument(loc, node.args[i], parameter.type)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 860, in _import_argument
    return self._import_literal(arg)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 877, in _import_literal
    return converter(py_value, self, self._cc)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 1169, in <lambda>
    lambda arg, gni, cc: _make_vtensor_literal_op(
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 1006, in _make_vtensor_literal_op
    npy_dtype is not None
AssertionError: Can not create literal tensor for unsupported datatype: torch.bfloat16

Steps to reproduce: make sure you have a Python environment with the torch-mlir package installed, then save the following file as model.py:

import torch
import torch.nn as nn

# Fx importer related
from typing import Optional
import torch.export
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d


def export_and_import(
    f,
    *args,
    fx_importer: Optional[FxImporter] = None,
    constraints: Optional[torch.export.Constraint] = None,
    **kwargs,
):
    """Export ``f`` with ``torch.export`` and import it into a torch-mlir module.

    Args:
        f: The callable / ``nn.Module`` to export.
        *args: Positional example inputs forwarded to ``torch.export.export``.
        fx_importer: Importer to populate. When omitted, a new ``FxImporter``
            is created on a fresh ``ir.Context`` with the torch dialect
            registered.
        constraints: Optional export constraints. Forwarded to
            ``torch.export.export`` only when not ``None``.
        **kwargs: Keyword example inputs forwarded to ``torch.export.export``.

    Returns:
        The importer's ``module_op`` holding the imported program.
    """
    if fx_importer is None:
        # Only create (and register the dialect on) a context when we own the
        # importer; a caller-supplied importer was built with its own context,
        # so a context created here would go unused.
        context = ir.Context()
        torch_d.register_dialect(context)
        fx_importer = FxImporter(context=context)

    export_kwargs = {}
    if constraints is not None:
        # NOTE(review): forward `constraints` only when the caller set it, so
        # the default path does not pass a keyword that some torch versions
        # may no longer accept — confirm against the installed torch version.
        export_kwargs["constraints"] = constraints
    prog = torch.export.export(f, args, kwargs, **export_kwargs)
    fx_importer.import_frozen_exported_program(prog)
    return fx_importer.module_op


class mlp(nn.Module):
    """Small two-layer perceptron used by this repro script.

    Layer stack: Linear(3, 4) -> ReLU -> Linear(4, 5) -> ReLU, so inputs
    with a trailing dimension of 3 produce outputs with a trailing
    dimension of 5.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            # 3 input, 4 output
            nn.Linear(3, 4),
            nn.ReLU(),
            # 4 input, 5 output (consumes the 4 features of the first layer)
            nn.Linear(4, 5),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.layers(x)

# Build the model and a float32 example input. The float32 forward pass serves
# as a reference printout before anything is cast to bfloat16.
# NOTE: statement order matters here — model init and torch.randn both draw
# from the global RNG, so reordering changes the printed values.
model = mlp()
test_input = torch.randn(8, 3)
test_output = model(test_input)
print("Input:", test_input)
print("Output:", test_output)
# Cast the parameters and the example input to bfloat16, then export/import.
# This import is the step that hits the reported AssertionError
# ("Can not create literal tensor for unsupported datatype: torch.bfloat16").
model = model.to(torch.bfloat16)
test_input = test_input.to(torch.bfloat16)
torch_mlir_model = export_and_import(model, test_input)
# Write the imported module's textual MLIR; reached only if the import succeeds.
with open("mlp.torch.mlir", "w+") as f:
    f.write(torch_mlir_model.operation.get_asm())

Run:

python ./model.py

— kumardeepakamd, Jan 31 '24 14:01

Has any work started on this?

— kumardeepakamd, Feb 09 '24 16:02