TransformerEngine (NVIDIA) — GitHub issue

Export to ONNX fails

Status: Open · opened by jbcdnr 1 year ago · 1 comment

Environment: Docker image `nvcr.io/nvidia/pytorch:23.09-py3` with `transformer-engine==1.0.0+66d91d5` installed.

Exporting a `TransformerLayer` to ONNX fails. A minimal reproducible example is below. Am I doing something wrong?

import transformer_engine
import torch

# The error message suggests adding these two lines, but they do not help.
import torch._dynamo

torch._dynamo.config.suppress_errors = True

te_layer = transformer_engine.pytorch.TransformerLayer(
    hidden_size=16,
    ffn_hidden_size=32,
    num_attention_heads=4,
    attention_dropout=0.0,
    hidden_dropout=0.0,
    layer_type="encoder",
)
te_layer = te_layer.to("cuda")

# Hidden states: [sq, b, h] = [64, 2, 16]; created directly on the GPU.
embed = torch.randn(64, 2, 16, device="cuda")

# Padding mask: [b, 1, 1, sk] = [2, 1, 1, 64], broadcast over heads and queries.
# Per the TransformerEngine docs, a True entry marks a position that is masked
# OUT. An all-True mask (torch.ones(...).bool()) would therefore hide every key.
# All-False lets every position participate in attention.
mask = torch.zeros(2, 1, 1, 64, dtype=torch.bool, device="cuda")

te_layer(embed, mask, self_attn_mask_type="padding")
print("forward done")

# NOTE(review): the reported traceback fails inside TE's LayerNormLinear.forward
# while it is wrapped by torch._dynamo's eval_frame during JIT tracing. That
# interaction lives in the library and is not changed by this script.
with torch.inference_mode():
    with transformer_engine.pytorch.onnx_export(enabled=True):
        torch.onnx.export(
            te_layer,
            # A trailing dict in `args` is forwarded to forward() as keyword args.
            (embed, mask, {"self_attn_mask_type": "padding"}),
            "model.onnx",
            # ONNX LayerNormalization first exists in opset 17, and TE's own ONNX
            # export tests use opset 17; the torch default opset is lower.
            opset_version=17,
        )
print("ONNX done")

Output

forward done
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
[/home/inceptive/test.py](https://untitled+.vscode-resource.vscode-cdn.net/home/inceptive/test.py) in line 30
     [28](file:///home/inceptive/test.py?line=27) with torch.inference_mode():
     [29](file:///home/inceptive/test.py?line=28)     with transformer_engine.pytorch.onnx_export(enabled=True):
---> [30](file:///home/inceptive/test.py?line=29)         torch.onnx.export(
     [31](file:///home/inceptive/test.py?line=30)             te_layer, (embed, mask, {"self_attn_mask_type": "padding"}), "model.onnx"
     [32](file:///home/inceptive/test.py?line=31)         )
     [33](file:///home/inceptive/test.py?line=32) print("ONNX done")

File [/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:516](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:516), in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, autograd_inlining)
    [189](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=188) @_beartype.beartype
    [190](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=189) def export(
    [191](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=190)     model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
   (...)
    [208](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=207)     autograd_inlining: Optional[bool] = True,
    [209](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=208) ) -> None:
    [210](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=209)     r"""Exports a model into ONNX format.
    [211](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=210) 
    [212](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=211)     If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
   (...)
    [513](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=512)             All errors are subclasses of :class:`errors.OnnxExporterError`.
    [514](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=513)     """
--> [516](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=515)     _export(
    [517](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=516)         model,
    [518](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=517)         args,
    [519](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=518)         f,
    [520](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=519)         export_params,
    [521](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=520)         verbose,
    [522](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=521)         training,
    [523](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=522)         input_names,
    [524](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=523)         output_names,
    [525](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=524)         operator_export_type=operator_export_type,
    [526](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=525)         opset_version=opset_version,
    [527](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=526)         do_constant_folding=do_constant_folding,
    [528](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=527)         dynamic_axes=dynamic_axes,
    [529](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=528)         keep_initializers_as_inputs=keep_initializers_as_inputs,
    [530](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=529)         custom_opsets=custom_opsets,
    [531](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=530)         export_modules_as_functions=export_modules_as_functions,
    [532](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=531)         autograd_inlining=autograd_inlining,
    [533](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=532)     )

File [/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:1582](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:1582), in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions, autograd_inlining)
   [1579](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1578)     dynamic_axes = {}
   [1580](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1579) _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
-> [1582](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1581) graph, params_dict, torch_out = _model_to_graph(
   [1583](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1582)     model,
   [1584](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1583)     args,
   [1585](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1584)     verbose,
   [1586](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1585)     input_names,
   [1587](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1586)     output_names,
   [1588](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1587)     operator_export_type,
   [1589](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1588)     val_do_constant_folding,
   [1590](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1589)     fixed_batch_size=fixed_batch_size,
   [1591](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1590)     training=training,
   [1592](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1591)     dynamic_axes=dynamic_axes,
   [1593](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1592) )
   [1595](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1594) # TODO: Don't allocate a in-memory string for the protobuf
   [1596](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1595) defer_weight_export = (
   [1597](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1596)     export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
   [1598](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1597) )

File [/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:1135](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:1135), in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
   [1132](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1131)     args = (args,)
   [1134](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1133) model = _pre_trace_quant_model(model, args)
-> [1135](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1134) graph, params, torch_out, module = _create_jit_graph(model, args)
   [1136](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1135) params_dict = _get_named_param_dict(graph, params)
   [1138](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1137) try:

File [/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:1011](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:1011), in _create_jit_graph(model, args)
   [1006](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1005)     graph = _C._propagate_and_assign_input_shapes(
   [1007](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1006)         graph, flattened_args, param_count_list, False, False
   [1008](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1007)     )
   [1009](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1008)     return graph, params, torch_out, None
-> [1011](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1010) graph, torch_out = _trace_and_get_graph_from_model(model, args)
   [1012](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1011) _C._jit_pass_onnx_lint(graph)
   [1013](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1012) state_dict = torch.jit._unique_state_dict(model)

File [/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:915](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:915), in _trace_and_get_graph_from_model(model, args)
    [913](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=912) prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
    [914](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=913) torch.set_autocast_cache_enabled(False)
--> [915](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=914) trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
    [916](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=915)     model,
    [917](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=916)     args,
    [918](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=917)     strict=False,
    [919](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=918)     _force_outplace=False,
    [920](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=919)     _return_inputs_states=True,
    [921](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=920) )
    [922](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=921) torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)
    [924](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=923) warn_on_static_input_change(inputs_states)

File [/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:333](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:333), in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    [331](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=330) dynamic_ctx.__enter__()
    [332](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=331) try:
--> [333](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=332)     return fn(*args, **kwargs)
    [334](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=333) finally:
    [335](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=334)     set_eval_frame(prior)

File [/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py:17](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py:17), in wrap_inline.<locals>.inner(*args, **kwargs)
     [15](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py?line=14) @functools.wraps(fn)
     [16](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py?line=15) def inner(*args, **kwargs):
---> [17](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py?line=16)     return fn(*args, **kwargs)

File [/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py:1287](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py:1287), in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
   [1285](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=1284) if not isinstance(args, tuple):
   [1286](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=1285)     args = (args,)
-> [1287](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=1286) outs = ONNXTracedModule(
   [1288](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=1287)     f, strict, _force_outplace, return_inputs, _return_inputs_states
   [1289](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=1288) )(*args, **kwargs)
   [1290](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=1289) return outs

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518), in Module._wrapped_call_impl(self, *args, **kwargs)
   [1516](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1515)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1517](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1516) else:
-> [1518](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1517)     return self._call_impl(*args, **kwargs)

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527), in Module._call_impl(self, *args, **kwargs)
   [1522](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1521) # If we don't have any hooks, we want to skip the rest of the logic in
   [1523](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1522) # this function, and just call forward.
   [1524](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1523) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1525](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1524)         or _global_backward_pre_hooks or _global_backward_hooks
   [1526](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1525)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1527](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1526)     return forward_call(*args, **kwargs)
   [1529](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1528) try:
   [1530](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1529)     result = None

File [/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py:133](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py:133), in ONNXTracedModule.forward(self, *args)
    [130](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=129)     else:
    [131](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=130)         return tuple(out_vars)
--> [133](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=132) graph, out = torch._C._create_graph_by_tracing(
    [134](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=133)     wrapper,
    [135](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=134)     in_vars + module_state,
    [136](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=135)     _create_interpreter_name_lookup_fn(),
    [137](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=136)     self.strict,
    [138](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=137)     self._force_outplace,
    [139](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=138) )
    [141](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=140) if self._return_inputs:
    [142](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=141)     return graph, outs[0], ret_inputs[0]

File [/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py:124](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py:124), in ONNXTracedModule.forward.<locals>.wrapper(*args)
    [122](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=121) if self._return_inputs_states:
    [123](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=122)     inputs_states.append(_unflatten(in_args, in_desc))
--> [124](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=123) outs.append(self.inner(*trace_inputs))
    [125](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=124) if self._return_inputs_states:
    [126](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=125)     inputs_states[0] = (inputs_states[0], trace_inputs)

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518), in Module._wrapped_call_impl(self, *args, **kwargs)
   [1516](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1515)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1517](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1516) else:
-> [1518](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1517)     return self._call_impl(*args, **kwargs)

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527), in Module._call_impl(self, *args, **kwargs)
   [1522](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1521) # If we don't have any hooks, we want to skip the rest of the logic in
   [1523](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1522) # this function, and just call forward.
   [1524](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1523) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1525](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1524)         or _global_backward_pre_hooks or _global_backward_hooks
   [1526](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1525)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1527](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1526)     return forward_call(*args, **kwargs)
   [1529](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1528) try:
   [1530](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1529)     result = None

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508), in Module._slow_forward(self, *input, **kwargs)
   [1506](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1505)         recording_scopes = False
   [1507](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1506) try:
-> [1508](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1507)     result = self.forward(*input, **kwargs)
   [1509](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1508) finally:
   [1510](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1509)     if recording_scopes:

File [~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py:566](https://untitled+.vscode-resource.vscode-cdn.net//~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py:566), in TransformerLayer.forward(self, hidden_states, attention_mask, self_attn_mask_type, encoder_output, enc_dec_attn_mask, is_first_microbatch, checkpoint_core_attention, inference_params, rotary_pos_emb, core_attention_bias_type, core_attention_bias, fast_zero_fill)
    [561](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=560)     hidden_states = cast_if_needed(
    [562](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=561)         hidden_states, torch.get_autocast_gpu_dtype()
    [563](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=562)     )
    [565](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=564) # Self attention.
--> [566](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=565) self_attention_outputs = self.self_attention(
    [567](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=566)     hidden_states,
    [568](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=567)     attention_mask=attention_mask,
    [569](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=568)     attn_mask_type=self_attn_mask_type,
    [570](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=569)     inference_params=inference_params,
    [571](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=570)     is_first_microbatch=is_first_microbatch,
    [572](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=571)     checkpoint_core_attention=checkpoint_core_attention,
    [573](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=572)     rotary_pos_emb=rotary_pos_emb,
    [574](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=573)     core_attention_bias_type=core_attention_bias_type,
    [575](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=574)     core_attention_bias=core_attention_bias,
    [576](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=575)     fast_zero_fill=fast_zero_fill,
    [577](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=576) )
    [579](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=578) if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
    [580](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=579)     attention_output, attention_bias, residual = self_attention_outputs

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518), in Module._wrapped_call_impl(self, *args, **kwargs)
   [1516](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1515)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1517](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1516) else:
-> [1518](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1517)     return self._call_impl(*args, **kwargs)

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527), in Module._call_impl(self, *args, **kwargs)
   [1522](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1521) # If we don't have any hooks, we want to skip the rest of the logic in
   [1523](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1522) # this function, and just call forward.
   [1524](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1523) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1525](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1524)         or _global_backward_pre_hooks or _global_backward_hooks
   [1526](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1525)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1527](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1526)     return forward_call(*args, **kwargs)
   [1529](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1528) try:
   [1530](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1529)     result = None

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508), in Module._slow_forward(self, *input, **kwargs)
   [1506](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1505)         recording_scopes = False
   [1507](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1506) try:
-> [1508](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1507)     result = self.forward(*input, **kwargs)
   [1509](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1508) finally:
   [1510](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1509)     if recording_scopes:

File [~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py:2706](https://untitled+.vscode-resource.vscode-cdn.net//~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py:2706), in MultiheadAttention.forward(self, hidden_states, attention_mask, encoder_output, attn_mask_type, is_first_microbatch, checkpoint_core_attention, inference_params, rotary_pos_emb, core_attention_bias_type, core_attention_bias, fast_zero_fill)
   [2703](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2702) if self.attention_type == "self":
   [2704](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2703)     # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
   [2705](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2704)     if self.input_layernorm:
-> [2706](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2705)         layernorm_qkv_outputs = self.layernorm_qkv(
   [2707](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2706)             hidden_states,
   [2708](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2707)             is_first_microbatch=is_first_microbatch,
   [2709](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2708)         )
   [2710](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2709)         if self.return_layernorm_output:
   [2711](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2710)             mixed_x_layer, layernorm_output = layernorm_qkv_outputs

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518), in Module._wrapped_call_impl(self, *args, **kwargs)
   [1516](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1515)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1517](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1516) else:
-> [1518](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1517)     return self._call_impl(*args, **kwargs)

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527), in Module._call_impl(self, *args, **kwargs)
   [1522](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1521) # If we don't have any hooks, we want to skip the rest of the logic in
   [1523](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1522) # this function, and just call forward.
   [1524](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1523) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1525](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1524)         or _global_backward_pre_hooks or _global_backward_hooks
   [1526](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1525)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1527](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1526)     return forward_call(*args, **kwargs)
   [1529](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1528) try:
   [1530](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1529)     result = None

File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508), in Module._slow_forward(self, *input, **kwargs)
   [1506](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1505)         recording_scopes = False
   [1507](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1506) try:
-> [1508](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1507)     result = self.forward(*input, **kwargs)
   [1509](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1508) finally:
   [1510](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1509)     if recording_scopes:

File [/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:333](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:333), in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    [331](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=330) dynamic_ctx.__enter__()
    [332](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=331) try:
--> [333](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=332)     return fn(*args, **kwargs)
    [334](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=333) finally:
    [335](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=334)     set_eval_frame(prior)

File [~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py:977](https://untitled+.vscode-resource.vscode-cdn.net//~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py:977), in LayerNormLinear.forward(self, inp, is_first_microbatch)
    [943](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=942)         args = [None]
    [944](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=943)     args += (
    [945](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=944)         inp,
    [946](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=945)         self.layer_norm_weight,
   (...)
    [975](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=974)         self.ub_atomic_gemm_ag,
    [976](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=975)     )
--> [977](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=976)     out = fwd_fn(*args)
    [979](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=978) if self.return_layernorm_output:
    [980](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=979)     out, ln_out = out

File [~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py:93](https://untitled+.vscode-resource.vscode-cdn.net//~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py:93), in _LayerNormLinear.forward(ctx, inp, ln_weight, ln_bias, weight, weight_fp8, weight_t_fp8, bias, use_bias, eps, is_first_microbatch, fp8, fp8_calibration, fp8_meta, fuse_wgrad_accumulation, tp_group, tp_size, sequence_parallel, tensor_parallel, activation_dtype, parallel_mode, return_layernorm_output, is_grad_enabled, fwd_ln_sm_margin, bwd_ln_sm_margin, zero_centered_gamma, normalization, primary_weights_in_fp8, ub_bulk_wgrad, ub_bulk_dgrad, ub_split_ag, ub_atomic_gemm_ag)
     [91](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=90) in_features = ln_weight.numel()
     [92](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=91) assert inp.shape[-1] == in_features, "GEMM not possible"
---> [93](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=92) inputmat = inp.view((-1, in_features))
     [94](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=93) if fp8:
     [95](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=94)     assert_dim_for_fp8_exec(inputmat)

RuntimeError: size == list_trace.size() INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/torch/csrc/jit/frontend/tracer.cpp":1014, please report a bug to PyTorch.

jbcdnr commented on Nov 21 '23 16:11

Has anyone managed to reproduce this bug? I'm happy to provide more context if needed.

jbcdnr commented on Dec 18 '23 13:12