TransformerEngine
TransformerEngine copied to clipboard
Export to ONNX fails
Environment: Docker image `nvcr.io/nvidia/pytorch:23.09-py3` with `transformer-engine==1.0.0+66d91d5` installed.
Exporting a `TransformerLayer` to ONNX format fails. A minimal reproducible example is below. Am I doing something wrong?
import transformer_engine
import torch
# The error message suggests adding these 2 lines, but it does not help
import torch._dynamo
torch._dynamo.config.suppress_errors = True
te_layer = transformer_engine.pytorch.TransformerLayer(
hidden_size=16,
ffn_hidden_size=32,
num_attention_heads=4,
attention_dropout=0.0,
hidden_dropout=0.0,
layer_type="encoder",
)
# hidden state [sq, b, h]
te_layer = te_layer.to("cuda")
embed = torch.randn(64, 2, 16).to("cuda")
# mask [b, np, sq, sk]
mask = torch.ones(2, 1, 1, 64).bool().to("cuda")
te_layer(embed, mask, self_attn_mask_type="padding")
print("forward done")
with torch.inference_mode():
with transformer_engine.pytorch.onnx_export(enabled=True):
torch.onnx.export(
te_layer, (embed, mask, {"self_attn_mask_type": "padding"}), "model.onnx"
)
print("ONNX done")
Output
forward done
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
[/home/inceptive/test.py](https://untitled+.vscode-resource.vscode-cdn.net/home/inceptive/test.py) in line 30
[28](file:///home/inceptive/test.py?line=27) with torch.inference_mode():
[29](file:///home/inceptive/test.py?line=28) with transformer_engine.pytorch.onnx_export(enabled=True):
---> [30](file:///home/inceptive/test.py?line=29) torch.onnx.export(
[31](file:///home/inceptive/test.py?line=30) te_layer, (embed, mask, {"self_attn_mask_type": "padding"}), "model.onnx"
[32](file:///home/inceptive/test.py?line=31) )
[33](file:///home/inceptive/test.py?line=32) print("ONNX done")
File [/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:516](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:516), in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, autograd_inlining)
[189](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=188) @_beartype.beartype
[190](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=189) def export(
[191](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=190) model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
(...)
[208](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=207) autograd_inlining: Optional[bool] = True,
[209](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=208) ) -> None:
[210](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=209) r"""Exports a model into ONNX format.
[211](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=210)
[212](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=211) If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
(...)
[513](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=512) All errors are subclasses of :class:`errors.OnnxExporterError`.
[514](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=513) """
--> [516](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=515) _export(
[517](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=516) model,
[518](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=517) args,
[519](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=518) f,
[520](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=519) export_params,
[521](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=520) verbose,
[522](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=521) training,
[523](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=522) input_names,
[524](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=523) output_names,
[525](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=524) operator_export_type=operator_export_type,
[526](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=525) opset_version=opset_version,
[527](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=526) do_constant_folding=do_constant_folding,
[528](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=527) dynamic_axes=dynamic_axes,
[529](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=528) keep_initializers_as_inputs=keep_initializers_as_inputs,
[530](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=529) custom_opsets=custom_opsets,
[531](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=530) export_modules_as_functions=export_modules_as_functions,
[532](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=531) autograd_inlining=autograd_inlining,
[533](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=532) )
File [/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:1582](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:1582), in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions, autograd_inlining)
[1579](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1578) dynamic_axes = {}
[1580](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1579) _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
-> [1582](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1581) graph, params_dict, torch_out = _model_to_graph(
[1583](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1582) model,
[1584](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1583) args,
[1585](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1584) verbose,
[1586](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1585) input_names,
[1587](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1586) output_names,
[1588](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1587) operator_export_type,
[1589](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1588) val_do_constant_folding,
[1590](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1589) fixed_batch_size=fixed_batch_size,
[1591](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1590) training=training,
[1592](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1591) dynamic_axes=dynamic_axes,
[1593](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1592) )
[1595](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1594) # TODO: Don't allocate a in-memory string for the protobuf
[1596](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1595) defer_weight_export = (
[1597](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1596) export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
[1598](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1597) )
File [/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:1135](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:1135), in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
[1132](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1131) args = (args,)
[1134](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1133) model = _pre_trace_quant_model(model, args)
-> [1135](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1134) graph, params, torch_out, module = _create_jit_graph(model, args)
[1136](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1135) params_dict = _get_named_param_dict(graph, params)
[1138](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1137) try:
File [/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:1011](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:1011), in _create_jit_graph(model, args)
[1006](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1005) graph = _C._propagate_and_assign_input_shapes(
[1007](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1006) graph, flattened_args, param_count_list, False, False
[1008](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1007) )
[1009](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1008) return graph, params, torch_out, None
-> [1011](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1010) graph, torch_out = _trace_and_get_graph_from_model(model, args)
[1012](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1011) _C._jit_pass_onnx_lint(graph)
[1013](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=1012) state_dict = torch.jit._unique_state_dict(model)
File [/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:915](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:915), in _trace_and_get_graph_from_model(model, args)
[913](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=912) prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
[914](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=913) torch.set_autocast_cache_enabled(False)
--> [915](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=914) trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
[916](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=915) model,
[917](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=916) args,
[918](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=917) strict=False,
[919](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=918) _force_outplace=False,
[920](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=919) _return_inputs_states=True,
[921](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=920) )
[922](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=921) torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)
[924](file:///usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py?line=923) warn_on_static_input_change(inputs_states)
File [/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:333](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:333), in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
[331](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=330) dynamic_ctx.__enter__()
[332](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=331) try:
--> [333](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=332) return fn(*args, **kwargs)
[334](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=333) finally:
[335](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=334) set_eval_frame(prior)
File [/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py:17](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py:17), in wrap_inline.<locals>.inner(*args, **kwargs)
[15](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py?line=14) @functools.wraps(fn)
[16](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py?line=15) def inner(*args, **kwargs):
---> [17](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py?line=16) return fn(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py:1287](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py:1287), in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
[1285](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=1284) if not isinstance(args, tuple):
[1286](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=1285) args = (args,)
-> [1287](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=1286) outs = ONNXTracedModule(
[1288](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=1287) f, strict, _force_outplace, return_inputs, _return_inputs_states
[1289](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=1288) )(*args, **kwargs)
[1290](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=1289) return outs
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518), in Module._wrapped_call_impl(self, *args, **kwargs)
[1516](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1515) return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[1517](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1516) else:
-> [1518](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1517) return self._call_impl(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527), in Module._call_impl(self, *args, **kwargs)
[1522](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1521) # If we don't have any hooks, we want to skip the rest of the logic in
[1523](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1522) # this function, and just call forward.
[1524](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1523) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1525](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1524) or _global_backward_pre_hooks or _global_backward_hooks
[1526](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1525) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1527](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1526) return forward_call(*args, **kwargs)
[1529](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1528) try:
[1530](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1529) result = None
File [/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py:133](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py:133), in ONNXTracedModule.forward(self, *args)
[130](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=129) else:
[131](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=130) return tuple(out_vars)
--> [133](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=132) graph, out = torch._C._create_graph_by_tracing(
[134](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=133) wrapper,
[135](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=134) in_vars + module_state,
[136](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=135) _create_interpreter_name_lookup_fn(),
[137](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=136) self.strict,
[138](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=137) self._force_outplace,
[139](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=138) )
[141](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=140) if self._return_inputs:
[142](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=141) return graph, outs[0], ret_inputs[0]
File [/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py:124](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py:124), in ONNXTracedModule.forward.<locals>.wrapper(*args)
[122](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=121) if self._return_inputs_states:
[123](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=122) inputs_states.append(_unflatten(in_args, in_desc))
--> [124](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=123) outs.append(self.inner(*trace_inputs))
[125](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=124) if self._return_inputs_states:
[126](file:///usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py?line=125) inputs_states[0] = (inputs_states[0], trace_inputs)
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518), in Module._wrapped_call_impl(self, *args, **kwargs)
[1516](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1515) return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[1517](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1516) else:
-> [1518](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1517) return self._call_impl(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527), in Module._call_impl(self, *args, **kwargs)
[1522](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1521) # If we don't have any hooks, we want to skip the rest of the logic in
[1523](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1522) # this function, and just call forward.
[1524](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1523) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1525](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1524) or _global_backward_pre_hooks or _global_backward_hooks
[1526](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1525) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1527](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1526) return forward_call(*args, **kwargs)
[1529](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1528) try:
[1530](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1529) result = None
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508), in Module._slow_forward(self, *input, **kwargs)
[1506](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1505) recording_scopes = False
[1507](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1506) try:
-> [1508](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1507) result = self.forward(*input, **kwargs)
[1509](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1508) finally:
[1510](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1509) if recording_scopes:
File [~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py:566](https://untitled+.vscode-resource.vscode-cdn.net//~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py:566), in TransformerLayer.forward(self, hidden_states, attention_mask, self_attn_mask_type, encoder_output, enc_dec_attn_mask, is_first_microbatch, checkpoint_core_attention, inference_params, rotary_pos_emb, core_attention_bias_type, core_attention_bias, fast_zero_fill)
[561](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=560) hidden_states = cast_if_needed(
[562](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=561) hidden_states, torch.get_autocast_gpu_dtype()
[563](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=562) )
[565](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=564) # Self attention.
--> [566](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=565) self_attention_outputs = self.self_attention(
[567](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=566) hidden_states,
[568](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=567) attention_mask=attention_mask,
[569](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=568) attn_mask_type=self_attn_mask_type,
[570](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=569) inference_params=inference_params,
[571](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=570) is_first_microbatch=is_first_microbatch,
[572](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=571) checkpoint_core_attention=checkpoint_core_attention,
[573](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=572) rotary_pos_emb=rotary_pos_emb,
[574](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=573) core_attention_bias_type=core_attention_bias_type,
[575](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=574) core_attention_bias=core_attention_bias,
[576](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=575) fast_zero_fill=fast_zero_fill,
[577](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=576) )
[579](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=578) if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
[580](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/transformer.py?line=579) attention_output, attention_bias, residual = self_attention_outputs
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518), in Module._wrapped_call_impl(self, *args, **kwargs)
[1516](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1515) return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[1517](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1516) else:
-> [1518](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1517) return self._call_impl(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527), in Module._call_impl(self, *args, **kwargs)
[1522](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1521) # If we don't have any hooks, we want to skip the rest of the logic in
[1523](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1522) # this function, and just call forward.
[1524](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1523) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1525](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1524) or _global_backward_pre_hooks or _global_backward_hooks
[1526](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1525) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1527](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1526) return forward_call(*args, **kwargs)
[1529](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1528) try:
[1530](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1529) result = None
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508), in Module._slow_forward(self, *input, **kwargs)
[1506](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1505) recording_scopes = False
[1507](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1506) try:
-> [1508](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1507) result = self.forward(*input, **kwargs)
[1509](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1508) finally:
[1510](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1509) if recording_scopes:
File [~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py:2706](https://untitled+.vscode-resource.vscode-cdn.net//~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py:2706), in MultiheadAttention.forward(self, hidden_states, attention_mask, encoder_output, attn_mask_type, is_first_microbatch, checkpoint_core_attention, inference_params, rotary_pos_emb, core_attention_bias_type, core_attention_bias, fast_zero_fill)
[2703](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2702) if self.attention_type == "self":
[2704](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2703) # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
[2705](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2704) if self.input_layernorm:
-> [2706](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2705) layernorm_qkv_outputs = self.layernorm_qkv(
[2707](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2706) hidden_states,
[2708](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2707) is_first_microbatch=is_first_microbatch,
[2709](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2708) )
[2710](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2709) if self.return_layernorm_output:
[2711](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py?line=2710) mixed_x_layer, layernorm_output = layernorm_qkv_outputs
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518), in Module._wrapped_call_impl(self, *args, **kwargs)
[1516](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1515) return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[1517](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1516) else:
-> [1518](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1517) return self._call_impl(*args, **kwargs)
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527), in Module._call_impl(self, *args, **kwargs)
[1522](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1521) # If we don't have any hooks, we want to skip the rest of the logic in
[1523](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1522) # this function, and just call forward.
[1524](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1523) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1525](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1524) or _global_backward_pre_hooks or _global_backward_hooks
[1526](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1525) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1527](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1526) return forward_call(*args, **kwargs)
[1529](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1528) try:
[1530](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1529) result = None
File [/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1508), in Module._slow_forward(self, *input, **kwargs)
[1506](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1505) recording_scopes = False
[1507](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1506) try:
-> [1508](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1507) result = self.forward(*input, **kwargs)
[1509](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1508) finally:
[1510](file:///usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py?line=1509) if recording_scopes:
File [/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:333](https://untitled+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:333), in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
[331](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=330) dynamic_ctx.__enter__()
[332](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=331) try:
--> [333](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=332) return fn(*args, **kwargs)
[334](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=333) finally:
[335](file:///usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py?line=334) set_eval_frame(prior)
File [~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py:977](https://untitled+.vscode-resource.vscode-cdn.net//~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py:977), in LayerNormLinear.forward(self, inp, is_first_microbatch)
[943](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=942) args = [None]
[944](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=943) args += (
[945](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=944) inp,
[946](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=945) self.layer_norm_weight,
(...)
[975](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=974) self.ub_atomic_gemm_ag,
[976](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=975) )
--> [977](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=976) out = fwd_fn(*args)
[979](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=978) if self.return_layernorm_output:
[980](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=979) out, ln_out = out
File [~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py:93](https://untitled+.vscode-resource.vscode-cdn.net//~/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py:93), in _LayerNormLinear.forward(ctx, inp, ln_weight, ln_bias, weight, weight_fp8, weight_t_fp8, bias, use_bias, eps, is_first_microbatch, fp8, fp8_calibration, fp8_meta, fuse_wgrad_accumulation, tp_group, tp_size, sequence_parallel, tensor_parallel, activation_dtype, parallel_mode, return_layernorm_output, is_grad_enabled, fwd_ln_sm_margin, bwd_ln_sm_margin, zero_centered_gamma, normalization, primary_weights_in_fp8, ub_bulk_wgrad, ub_bulk_dgrad, ub_split_ag, ub_atomic_gemm_ag)
[91](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=90) in_features = ln_weight.numel()
[92](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=91) assert inp.shape[-1] == in_features, "GEMM not possible"
---> [93](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=92) inputmat = inp.view((-1, in_features))
[94](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=93) if fp8:
[95](file:///home/inceptive/.local/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_linear.py?line=94) assert_dim_for_fp8_exec(inputmat)
RuntimeError: size == list_trace.size() INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/torch/csrc/jit/frontend/tracer.cpp":1014, please report a bug to PyTorch.
Did anyone manage to reproduce this bug? Happy to provide more context if needed.