
[torchbench] `Super_SloMo` failing on inference with dynamo after `bfloat16` conversion.

ysiraichi opened this issue 1 year ago • 0 comments

🐛 Bug

After converting the Super_SloMo model to bfloat16, running inference with dynamo raises the following error:

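For context, a sketch of how this kind of mixed-dtype situation typically arises: `module.to(torch.bfloat16)` converts registered parameters and buffers, but tensors stored as plain Python attributes keep their original dtype. This is an illustrative guess at the failure mode, not the actual Super_SloMo code (the `Warp` module and its `grid` attribute below are hypothetical):

```python
import torch

class Warp(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain attribute: NOT converted by .to(dtype), unlike a tensor
        # registered via self.register_buffer(...).
        self.grid = torch.randn(1, 8, 8, 2)

    def forward(self, img):
        # img may be bfloat16 after conversion while self.grid stays float32,
        # so grid_sample sees mismatched dtypes.
        return torch.nn.functional.grid_sample(
            img, self.grid, align_corners=False)

m = Warp().to(torch.bfloat16)
print(m.grid.dtype)  # torch.float32
```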
python xla/benchmarks/experiment_runner.py \
    --suite-name torchbench --accelerator cuda \
    --xla PJRT --dynamo openxla --test eval \
    --no-resume --print-subprocess \
    -k Super_SloMo
Traceback (most recent call last):
  File "xla/benchmarks/experiment_runner.py", line 914, in <module>
    main()
  File "xla/benchmarks/experiment_runner.py", line 910, in main
    runner.run()
  File "xla/benchmarks/experiment_runner.py", line 59, in run
    self.run_single_config()
  File "xla/benchmarks/experiment_runner.py", line 254, in run_single_config
    metrics, last_output = self.run_once_and_gather_metrics(
  File "xla/benchmarks/experiment_runner.py", line 331, in run_once_and_gather_metrics
    output, _ = loop(iter_fn=self._default_iter_fn)
  File "xla/benchmarks/experiment_runner.py", line 300, in loop
    output, timing, trace = iter_fn(benchmark_experiment, benchmark_model,
  File "xla/benchmarks/experiment_runner.py", line 216, in _default_iter_fn
    output = benchmark_model.model_iter_fn(
  File "torch/_dynamo/eval_frame.py", line 454, in _fn
    return fn(*args, **kwargs)
  File "xla/benchmarks/benchmark_model.py", line 168, in eval
    def eval(self, inputs, collect_full_output=False):
  File "torch/_dynamo/eval_frame.py", line 454, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 25, in inner
    return fn(*args, **kwargs)
  File "torch/_functorch/aot_autograd.py", line 893, in forward
    return compiled_fn(full_args)
  File "torch/_functorch/_aot_autograd/utils.py", line 79, in g
    return f(*args)
  File "torch/_functorch/_aot_autograd/runtime_wrappers.py", line 101, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "torch/_functorch/_aot_autograd/utils.py", line 103, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "torch/_functorch/_aot_autograd/utils.py", line 79, in g
    return f(*args)
  File "torch/_dynamo/backends/torchxla.py", line 51, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "xla/torch_xla/core/dynamo_bridge.py", line 543, in extract_compiled_graph
    collector.run(*xla_args)
  File "torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "xla/torch_xla/core/dynamo_bridge.py", line 431, in run_node
    result = super().run_node(n)
  File "torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "torch/fx/interpreter.py", line 274, in call_function
    return target(*args, **kwargs)
  File "torch/_ops.py", line 571, in __call__
    return self_._op(*args, **kwargs)
RuntimeError: expected scalar type Float but found Half

While executing %grid_sampler_2d_4 : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%arg114_1, %stack_4, 0, 0, False), kwargs = {})
Original traceback:
  File "xla/benchmarks/benchmark_model.py", line 170, in eval
    pred = self.module(*inputs)
  File "torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "benchmark/torchbenchmark/models/Super_SloMo/model_wrapper.py", line 68, in forward
    warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(g_I1_F_t_1, IFrame) + L1_lossFn(self.trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(self.trainFlowBackWarp(I1, F_0_1), I0)
  File "torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "benchmark/torchbenchmark/models/Super_SloMo/slomo_model.py", line 288, in forward
    imgOut = torch.nn.functional.grid_sample(img, grid)

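The failing op is `aten.grid_sampler_2d`, which requires its input and grid arguments to share a dtype. A minimal eager-mode sketch of the same class of failure (the exact error text differs from the XLA message above, and the shapes here are arbitrary):

```python
import torch

# Input converted to bfloat16, but the sampling grid left in float32.
img = torch.randn(1, 3, 8, 8, dtype=torch.bfloat16)
grid = torch.randn(1, 8, 8, 2)  # defaults to torch.float32

try:
    torch.nn.functional.grid_sample(img, grid, align_corners=False)
except RuntimeError as e:
    # grid_sampler rejects mismatched input/grid dtypes.
    print(type(e).__name__)
```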
Affected Configurations

  • Dynamo Inference
  • Dynamo Training

Environment

  • Reproducible on XLA backend [CPU/TPU]: CUDA
  • torch_xla version: 20692cb04256ca24b1dd2c7d00ffba0bb0cb69ac

cc @miladm @JackCaoG

ysiraichi avatar Feb 16 '24 19:02 ysiraichi