xla
xla copied to clipboard
[torchbench] `Super_SloMo` failing on inference with dynamo after `bfloat16` conversion.
🐛 Bug
After converting the `Super_SloMo` model to `bfloat16`, running inference on dynamo raises the following error:
python xla/benchmarks/experiment_runner.py \
--suite-name torchbench --accelerator cuda \
--xla PJRT --dynamo openxla --test eval \
--no-resume --print-subprocess \
-k Super_SloMo
Traceback (most recent call last):
File "xla/benchmarks/experiment_runner.py", line 914, in <module>
main()
File "xla/benchmarks/experiment_runner.py", line 910, in main
runner.run()
File "xla/benchmarks/experiment_runner.py", line 59, in run
self.run_single_config()
File "xla/benchmarks/experiment_runner.py", line 254, in run_single_config
metrics, last_output = self.run_once_and_gather_metrics(
File "xla/benchmarks/experiment_runner.py", line 331, in run_once_and_gather_metrics
output, _ = loop(iter_fn=self._default_iter_fn)
File "xla/benchmarks/experiment_runner.py", line 300, in loop
output, timing, trace = iter_fn(benchmark_experiment, benchmark_model,
File "xla/benchmarks/experiment_runner.py", line 216, in _default_iter_fn
output = benchmark_model.model_iter_fn(
File "torch/_dynamo/eval_frame.py", line 454, in _fn
return fn(*args, **kwargs)
File "xla/benchmarks/benchmark_model.py", line 168, in eval
def eval(self, inputs, collect_full_output=False):
File "torch/_dynamo/eval_frame.py", line 454, in _fn
return fn(*args, **kwargs)
File "torch/_dynamo/external_utils.py", line 25, in inner
return fn(*args, **kwargs)
File "torch/_functorch/aot_autograd.py", line 893, in forward
return compiled_fn(full_args)
File "torch/_functorch/_aot_autograd/utils.py", line 79, in g
return f(*args)
File "torch/_functorch/_aot_autograd/runtime_wrappers.py", line 101, in runtime_wrapper
all_outs = call_func_at_runtime_with_args(
File "torch/_functorch/_aot_autograd/utils.py", line 103, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
File "torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118, in rng_functionalization_wrapper
return compiled_fw(args)
File "torch/_functorch/_aot_autograd/utils.py", line 79, in g
return f(*args)
File "torch/_dynamo/backends/torchxla.py", line 51, in fwd
compiled_graph = bridge.extract_compiled_graph(model, args)
File "xla/torch_xla/core/dynamo_bridge.py", line 543, in extract_compiled_graph
collector.run(*xla_args)
File "torch/fx/interpreter.py", line 145, in run
self.env[node] = self.run_node(node)
File "xla/torch_xla/core/dynamo_bridge.py", line 431, in run_node
result = super().run_node(n)
File "torch/fx/interpreter.py", line 202, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "torch/fx/interpreter.py", line 274, in call_function
return target(*args, **kwargs)
File "torch/_ops.py", line 571, in __call__
return self_._op(*args, **kwargs)
RuntimeError: expected scalar type Float but found Half
While executing %grid_sampler_2d_4 : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%arg114_1, %stack_4, 0, 0, False), kwargs = {})
Original traceback:
File "xla/benchmarks/benchmark_model.py", line 170, in eval
pred = self.module(*inputs)
File "torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "benchmark/torchbenchmark/models/Super_SloMo/model_wrapper.py", line 68, in forward
warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(g_I1_F_t_1, IFrame) + L1_lossFn(self.trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(self.trainFlowBackWarp(I1, F_0_1), I0)
File "torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "benchmark/torchbenchmark/models/Super_SloMo/slomo_model.py", line 288, in forward
imgOut = torch.nn.functional.grid_sample(img, grid)
Affected Configurations
- Dynamo Inference
- Dynamo Training
Environment
- Reproducible on XLA backend [CPU/TPU]: CUDA
- torch_xla version: 20692cb04256ca24b1dd2c7d00ffba0bb0cb69ac
cc @miladm @JackCaoG