
[torchbench] Detectron2 benchmarks failing to run.

Open ysiraichi opened this issue 1 year ago • 16 comments

🐛 Bug

After #6296, a few detectron2 benchmarks started failing when using XLA:

python xla/benchmarks/experiment_runner.py \
    --suite-name torchbench --accelerator cuda --repeat 2 \
    --test eval --xla PJRT --dynamo openxla \
    -k detectron2_fasterrcnn_r_50_c4
Traceback (most recent call last):
  File "xla/benchmarks/experiment_runner.py", line 906, in <module>
    main()
  File "xla/benchmarks/experiment_runner.py", line 902, in main
    runner.run()
  File "xla/benchmarks/experiment_runner.py", line 59, in run
    self.run_single_config()
  File "xla/benchmarks/experiment_runner.py", line 247, in run_single_config
    metrics, last_output = self.run_once_and_gather_metrics(
  File "xla/benchmarks/experiment_runner.py", line 324, in run_once_and_gather_metrics
    output, _ = loop(iter_fn=self._default_iter_fn)
  File "xla/benchmarks/experiment_runner.py", line 293, in loop
    output, timing, trace = iter_fn(benchmark_experiment, benchmark_model,
  File "xla/benchmarks/experiment_runner.py", line 209, in _default_iter_fn
    output = benchmark_model.model_iter_fn(
  File "xla/benchmarks/benchmark_model.py", line 155, in eval
    pred = self.module(*inputs)
  File "torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/detectron2/modeling/meta_arch/rcnn.py", line 150, in forward
    return self.inference(batched_inputs)
  File "/lib/python3.8/site-packages/detectron2/modeling/meta_arch/rcnn.py", line 208, in inference
    proposals, _ = self.proposal_generator(images, features, None)
  File "torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/detectron2/modeling/proposal_generator/rpn.py", line 454, in forward
    pred_objectness_logits, pred_anchor_deltas = self.rpn_head(features)
  File "torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/detectron2/modeling/proposal_generator/rpn.py", line 175, in forward
    pred_objectness_logits.append(self.objectness_logits(t))
  File "torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::Half) should be the same
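
For reference, a minimal sketch of this failure mode (my illustration, not the benchmark code): a conv layer whose parameters were converted to float16 receiving a float32 input trips the same kind of dtype-mismatch check.

import torch

# Conv parameters become float16, but the input stays float32.
conv = torch.nn.Conv2d(3, 8, kernel_size=3).half()
x = torch.randn(1, 3, 16, 16)  # float32
conv(x)  # RuntimeError: dtype mismatch between input and conv parameters, as above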

Environment

  • Reproducible on XLA backend [CPU/TPU]: CUDA
  • torch_xla version: 40727e4d367e183baccd9a2ce734ca7632ca09ac
    • But, basically, any commit since #6296 reproduces it

cc @miladm @JackCaoG

ysiraichi avatar Jan 19 '24 19:01 ysiraichi

This error seems to come from PyTorch, which is weird... do we know why the bias is fp16?

JackCaoG avatar Jan 19 '24 19:01 JackCaoG

I don't think #6296 is wrong or should be reverted, since I believe it is the best way to compare against inductor: instantiate the module on the original accelerator, then move it to XLA. That said, I can think of 2 solutions:

  1. (easy) Special-case these models, so that only they are instantiated on the XLA device
  2. (hard) Investigate what's actually going on there

In particular, I believe (2) is better, so I will focus on that.

ysiraichi avatar Jan 19 '24 19:01 ysiraichi

do we know why the bias is fp16?

Not really sure. I still have to investigate.

ysiraichi avatar Jan 19 '24 19:01 ysiraichi

But, yes, this seems like PyTorch is doing something weird.

ysiraichi avatar Jan 19 '24 19:01 ysiraichi

I've just confirmed that instantiating the model on the XLA device fixes the error, i.e. changing the line below to use str(self.benchmark_experiment.get_device()):

https://github.com/pytorch/xla/blob/423bb0b295319a692ee21787edbff50d07361db7/benchmarks/torchbench_model.py#L233

ysiraichi avatar Jan 19 '24 19:01 ysiraichi

I've just confirmed that instantiating the model on the XLA device fixes the error, i.e. changing the line below to use str(self.benchmark_experiment.get_device()):

https://github.com/pytorch/xla/blob/423bb0b295319a692ee21787edbff50d07361db7/benchmarks/torchbench_model.py#L233

I thought you wanted to instantiate the model on CUDA and then move it to XLA?

vanbasten23 avatar Jan 20 '24 03:01 vanbasten23

Yes, I do think that's better. I was just confirming that this was the change that caused these models to break.

ysiraichi avatar Jan 20 '24 13:01 ysiraichi

After further investigation, I found that the issue is due to a combination of two factors:

  • Both the model and the example inputs are converted to float16
  • The XLA_USE_FP16 environment variable is set

This causes the function AtenFromXlaTensor (at the end of a relu dispatch) to call MaybeUpcastToHostTorchType, which converts the float16 result back to a float32 tensor.
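
A hedged sketch of that interaction (assumes a torch_xla build is available, and that XLA_USE_FP16 is set before torch_xla is imported; the printed dtype reflects the upcasting behavior described above, which I haven't independently verified):

import os
os.environ["XLA_USE_FP16"] = "1"  # must be set before importing torch_xla

import torch
import torch_xla.core.xla_model as xm

# The tensor is genuinely float16 on the XLA device...
x = torch.randn(4, 4).half().to(xm.xla_device())
y = torch.relu(x)
# ...but MaybeUpcastToHostTorchType reports the result as float32, so the
# next conv sees a float input against c10::Half parameters.
print(y.dtype)  # torch.float32, per the behavior described above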


Why wasn't it failing before?

Torchbench already converts the model and example inputs to float16, due to the DEFAULT_EVAL_CUDA_PRECISION variable being set for detectron2 models, when the device is actually cuda. However, before #6296, the models were being initialized with something other than cuda (using str(self.benchmark_experiment.get_device())). Thus, the model was never actually converted to float16, avoiding the error.


Possible Solutions

I can think of a few possible solutions:

  1. Do not set XLA_USE_FP16=1, since the model is already being converted to float16
  2. Remember the original data-type of the input tensor before it is downcast to float16
  3. Have a set of models, INIT_WITH_XLA_DEVICE, special-casing them so that they don't get downcast

Of those, I think (1) is the best one. Maybe we should also throw a specific error when float16 tensors are used while the XLA_USE_FP16 environment variable is set (sketched below).
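
A minimal sketch of that guard (the function name and call site are my assumptions, not existing runner code):

import os

import torch

def check_precision_flags(example_inputs):
    # Fail fast if the inputs were already downcast to float16 while
    # XLA_USE_FP16 is also set, since the two mechanisms conflict.
    already_fp16 = any(
        isinstance(t, torch.Tensor) and t.dtype == torch.float16
        for t in example_inputs)
    if already_fp16 and os.environ.get("XLA_USE_FP16") == "1":
        raise RuntimeError(
            "Inputs are already float16 while XLA_USE_FP16=1 is set; "
            "unset XLA_USE_FP16 or keep the model in float32.")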

@miladm @JackCaoG @vanbasten23 Let me know what you think.

ysiraichi avatar Jan 21 '24 15:01 ysiraichi

Apparently, after doing (1), I am getting another error:

  File "/lib/python3.8/site-packages/detectron2/modeling/proposal_generator/proposal_utils.py", line 121, in find_top_rpn_proposals
    keep = batched_nms(boxes.tensor, scores_per_img, lvl, nms_thresh)
  File "/lib/python3.8/site-packages/detectron2/layers/nms.py", line 20, in batched_nms
    return box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold)
...
  File "/lib/python3.8/site-packages/torchvision/torchvision/ops/boxes.py", line 109, in resume_in__batched_nms_vanilla_at_107
    curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold)
  File "/lib/python3.8/site-packages/torchvision/torchvision/ops/boxes.py", line 41, in nms
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
  File "torch/_ops.py", line 825, in __call__
    return self_._op(*args, **(kwargs or {}))
RuntimeError: dets (Float) should have the same type as scores (Half)

In summary:

  • Detectron2 is casting one of the nms inputs to float32
  • nms execution is falling back to the CPU implementation, which doesn't accept different types
    • Interestingly, the CUDA implementation does accept them (a minimal repro follows below)
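
A minimal repro of that CPU-kernel restriction (detectron2 upcasts the boxes to float32 while the scores stay float16):

import torch
from torchvision.ops import nms

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                      [1.0, 1.0, 11.0, 11.0]])          # float32
scores = torch.tensor([0.9, 0.8], dtype=torch.float16)  # float16
nms(boxes, scores, iou_threshold=0.5)
# RuntimeError: dets (Float) should have the same type as scores (Half)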

Given the problem above, I think we should go with solution (3) in the short term. Let me know what you all think.

ysiraichi avatar Jan 21 '24 15:01 ysiraichi

FYI: nn.Module.to(torch.float16) has a bug; I opened https://github.com/pytorch/pytorch/issues/115792 for it. This is why we still have to use XLA_USE_FP16.

JackCaoG avatar Jan 22 '24 18:01 JackCaoG

I see. So maybe a solution is to pass --precision fp32 when instantiating the benchmark, while keeping XLA_USE_FP16 set. What do you think?
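
Something like the following invocation (assuming the runner forwards a --precision flag to torchbench as suggested; I haven't verified the exact flag name):

XLA_USE_FP16=1 python xla/benchmarks/experiment_runner.py \
    --suite-name torchbench --accelerator cuda --repeat 2 \
    --test eval --xla PJRT --dynamo openxla \
    --precision fp32 \
    -k detectron2_fasterrcnn_r_50_c4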

ysiraichi avatar Jan 22 '24 19:01 ysiraichi

That seems to be a reasonable workaround until upstream fixes the model.to issue (which I think they are working on; I saw some PRs floating around).

JackCaoG avatar Jan 22 '24 19:01 JackCaoG

This issue was temporarily fixed by #6389. #6404 details a better fix to this upcasting problem, and the actual problem description is in #6403.

ysiraichi avatar Feb 05 '24 14:02 ysiraichi

Apparently, this issue was not due to the conversion issues (https://github.com/pytorch/pytorch/issues/115792) we once thought, but is a real problem (more details in this comment).

ysiraichi avatar Feb 12 '24 18:02 ysiraichi

@miladm @JackCaoG

Here's what I found when looking into this issue (nms falling back to the CPU kernel): even though there's an implementation of nms inside PyTorch/XLA, it appears that it is only hooked up to a Python function in torch_xla/core/functions.py.

Registering an implementation for XLA with the dispatcher should solve this problem. That said, I don't think we can leverage the current codegen infrastructure, since nms is a torchvision kernel.

What I think could be done: register the XLA implementation manually, e.g.:

#include <torch/library.h>  // TORCH_LIBRARY_IMPL, TORCH_SELECTIVE_NAME, TORCH_FN

// Route torchvision::nms to the XLA kernel whenever dispatch sees XLA tensors.
TORCH_LIBRARY_IMPL(torchvision, XLA, m) {
  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(xla_nms_kernel));
}
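
With a registration like that in place, TorchVision's Python wrapper should route XLA tensors to xla_nms_kernel instead of falling back to CPU. A hedged usage sketch (assumes a torch_xla build containing the registration above; xla_nms_kernel is the C++ entry point named there):

import torch
import torch_xla.core.xla_model as xm
from torchvision.ops import nms

device = xm.xla_device()
boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                      [1.0, 1.0, 11.0, 11.0]], device=device)
scores = torch.tensor([0.9, 0.8], device=device)
keep = nms(boxes, scores, iou_threshold=0.5)  # now dispatches on the XLA key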

Let me know what you think.

ysiraichi avatar Mar 12 '24 00:03 ysiraichi

I don't think the XLA nms is well tested; it is better to figure out what that op does and test it before we register it as the default implementation.

JackCaoG avatar Mar 12 '24 17:03 JackCaoG

@JackCaoG While the solution in this comment works, I thought it would make more sense to implement a CompositeExplicitAutograd version in TorchVision directly. What do you think?

ysiraichi avatar Mar 15 '24 14:03 ysiraichi

what difference does it make to mark it as CompositeExplicitAutograd?

JackCaoG avatar Mar 15 '24 17:03 JackCaoG

The difference is that it would decompose into, hopefully, already-supported operations. That said, I'm thinking of the following plan:

  • Make an XLA kernel for nms on PyTorch/XLA (original idea)
  • Talk with TorchVision maintainers, and create a CompositeExplicitAutograd kernel for nms
  • Kill XLA kernel implementation

I think it would be better to have the composite kernel (a rough sketch follows below) because:

  • It's easier to maintain
  • One less kernel to maintain inside PyTorch/XLA
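
To make the idea concrete, here is a rough sketch of a decomposed nms written only with generic tensor ops; this is my illustration, not TorchVision's actual kernel, and a real composite version would need to handle dynamic shapes carefully:

import torch

def nms_composite(boxes, scores, iou_threshold):
    # boxes: (N, 4) as (x1, y1, x2, y2); scores: (N,)
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0]
        keep.append(i)
        if order.numel() == 1:
            break
        rest = order[1:]
        # IoU between the kept box and every remaining candidate.
        lt = torch.maximum(boxes[i, :2], boxes[rest, :2])
        rb = torch.minimum(boxes[i, 2:], boxes[rest, 2:])
        inter = (rb - lt).clamp(min=0).prod(dim=1)
        area_i = (boxes[i, 2:] - boxes[i, :2]).prod()
        area_rest = (boxes[rest, 2:] - boxes[rest, :2]).prod(dim=1)
        iou = inter / (area_i + area_rest - inter)
        order = rest[iou <= iou_threshold]  # drop overlapping candidates
    if not keep:
        return torch.empty(0, dtype=torch.long)
    return torch.stack(keep)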

ysiraichi avatar Mar 15 '24 21:03 ysiraichi

@JackCaoG Question: how important is it to keep the old behavior? (A compatibility shim like the one sketched below could bridge the two signatures.)

  • Current nms signature: nms(boxes, scores, score_threshold, iou_threshold, output_size)
  • TorchVision nms signature: nms(boxes, scores, iou_threshold)
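
For illustration, a hedged shim from the old signature onto TorchVision's (my sketch; it does not reproduce any padding-to-output_size semantics the old op may have had):

import torch
from torchvision.ops import nms

def nms_old_signature(boxes, scores, score_threshold, iou_threshold, output_size):
    # Emulate the extra parameters: pre-filter by score, then truncate.
    idx = (scores > score_threshold).nonzero().squeeze(1)
    keep = idx[nms(boxes[idx], scores[idx], iou_threshold)]
    return keep[:output_size]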

ysiraichi avatar Mar 21 '24 14:03 ysiraichi

ehh, not very? I guess no one is using our nms at the moment.

JackCaoG avatar Mar 21 '24 23:03 JackCaoG

So, can we kill it, in favor of the TorchVision variant?

ysiraichi avatar Mar 22 '24 14:03 ysiraichi

yea

JackCaoG avatar Mar 22 '24 17:03 JackCaoG