torch-mlir
torch-mlir copied to clipboard
Fx importer does not support bfloat16
Error seen:
Traceback (most recent call last):
File "/proj/gdba/kumar/nod/SHARK-TestSuite/e2eshark/t-r-bf16-direct-fx-importer/pytorch/combinations/mlp/runmodel.py", line 131, in <module>
torch_mlir_model = export_and_import(model, test_input)
File "/proj/gdba/kumar/nod/SHARK-TestSuite/e2eshark/t-r-bf16-direct-fx-importer/pytorch/combinations/mlp/runmodel.py", line 59, in export_and_import
fx_importer.import_frozen_exported_program(prog)
File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 351, in import_frozen_exported_program
self.import_stateless_graph(g)
File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 377, in import_stateless_graph
node_importer.import_nodes(g.nodes)
File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 620, in import_nodes
self._import_torch_op_overload(loc, node, target)
File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 798, in _import_torch_op_overload
self._import_argument(loc, node.args[i], parameter.type)
File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 860, in _import_argument
return self._import_literal(arg)
File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 877, in _import_literal
return converter(py_value, self, self._cc)
File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 1169, in <lambda>
lambda arg, gni, cc: _make_vtensor_literal_op(
File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 1006, in _make_vtensor_literal_op
npy_dtype is not None
AssertionError: Can not create literal tensor for unsupported datatype: torch.bfloat16
Steps to reproduce: Make sure you have a Python environment with the torch-mlir package installed. Save the following file as model.py:
import torch
import torch.nn as nn
# Fx importer related
from typing import Optional
import torch.export
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
def export_and_import(
    f,
    *args,
    fx_importer: Optional[FxImporter] = None,
    constraints: Optional[torch.export.Constraint] = None,
    **kwargs,
):
    """Export ``f`` with ``torch.export`` and import it through an FxImporter.

    Args:
        f: The callable/module handed to ``torch.export.export``.
        *args: Positional example inputs forwarded to the export call.
        fx_importer: Importer to use. When ``None``, a new ``FxImporter`` is
            built on the MLIR context created inside this function.
        constraints: Forwarded unchanged as the ``constraints`` keyword of
            ``torch.export.export``.
        **kwargs: Keyword example inputs forwarded to the export call.

    Returns:
        The ``module_op`` attribute of the importer after the import.
    """
    # The context is always created and the torch dialect registered on it,
    # even if the caller supplies an importer of their own.
    mlir_context = ir.Context()
    torch_d.register_dialect(mlir_context)

    importer = (
        fx_importer
        if fx_importer is not None
        else FxImporter(context=mlir_context)
    )

    exported_program = torch.export.export(
        f, args, kwargs, constraints=constraints
    )
    importer.import_frozen_exported_program(exported_program)
    return importer.module_op
class mlp(nn.Module):
    """Small multi-layer perceptron used by the reproduction script.

    The network maps 3 input features to 5 output features through a hidden
    layer of width 4, with a ReLU after each linear layer.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            # 3 inputs, 4 outputs
            nn.Linear(3, 4),
            nn.ReLU(),
            # 4 inputs, 5 outputs
            nn.Linear(4, 5),
            nn.ReLU(),
        )

    def forward(self, x):
        """Pass ``x`` through the sequential stack and return the result."""
        return self.layers(x)
# First run the model eagerly in its default dtype and show the result.
model = mlp()
test_input = torch.randn(8, 3)
test_output = model(test_input)
print("Input:", test_input)
print("Output:", test_output)

# Convert both the model and its input to bfloat16 before exporting; the
# import of this bfloat16 program is the step that raises the AssertionError
# reported in the traceback.
model = model.to(torch.bfloat16)
test_input = test_input.to(torch.bfloat16)
torch_mlir_model = export_and_import(model, test_input)

# Dump the imported module's assembly to disk.
with open("mlp.torch.mlir", "w+") as mlir_file:
    mlir_file.write(torch_mlir_model.operation.get_asm())
Run:
python ./model.py
Has any work been started on this?