
Codegen clamp.Tensor

Open · wonjoolee95 opened this issue 2 years ago · 2 comments

Fixes https://github.com/pytorch/xla/issues/3861

Codegen clamp.Tensor


LazyIr.h:

class ClampTensor : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::clamp);
  }

  ClampTensor(const torch::lazy::Value& self, const c10::optional<torch::lazy::Value>& min, const c10::optional<torch::lazy::Value>& max, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::clamp),
              {self, min.value_or(kNullValue), max.value_or(kNullValue)}, std::move(shapes),
              [&]() { return ClampTensorOutputShape(self, min, max); },
              /* num_outputs */ 1,
              torch::lazy::MHash())
  {
    has_min = !!min;
    has_max = !!max;
  }

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self, const c10::optional<torch::lazy::Value>& min, const c10::optional<torch::lazy::Value>& max) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;

  
  bool has_min: 1;
  bool has_max: 1;

};
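
The generated class above only declares Lower and defers its output shape to a hand-written ClampTensorOutputShape, so both still need to be implemented by hand. As a minimal sketch of what the lowering could look like, assuming torch_xla's usual helpers (loctx->GetOutputOp, ReturnOp, xla::Min/xla::Max) and ignoring the type promotion that the discussion below shows is actually needed:

torch_xla::XlaOpVector ClampTensor::Lower(LoweringContext* loctx) const {
  // Sketch only, not this PR's implementation. Null optional operands are
  // skipped when the node is built, so the operand index of min/max depends
  // on the has_min/has_max flags set in the constructor.
  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));  // self
  size_t next_operand = 1;
  if (has_min) {
    xla::XlaOp xla_min = loctx->GetOutputOp(operand(next_operand++));
    xla_input = xla::Max(xla_input, xla_min);
  }
  if (has_max) {
    xla::XlaOp xla_max = loctx->GetOutputOp(operand(next_operand++));
    xla_input = xla::Min(xla_input, xla_max);
  }
  return ReturnOp(xla_input, loctx);
}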

XLANativeFunctions.cpp:

    at::Tensor XLANativeFunctions::clamp_max(const at::Tensor & self, const at::Tensor & max) {
        
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self, max);
        TORCH_INTERNAL_ASSERT(common_device);
        
        torch_xla::XLATensorPtr lazy_self = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch_xla::XLATensorPtr lazy_max = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(max, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<ClampMaxTensor>(lazy_self->GetIrValue(), lazy_max->GetIrValue());
        if (!node) {
            auto self_meta = to_meta(self);
            auto max_meta = to_meta(max);
            auto out_meta = at::meta::clamp_max(self_meta, max_meta);

            std::vector<torch::lazy::Shape> shapes{
                torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if (torch::lazy::symbolicShapeEnabled()) {
                std::vector<torch::jit::IValue> inputs = {self, max};
                const char* schema_str = "aten::clamp_max.Tensor(Tensor self, Tensor max) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }

            node = torch::lazy::MakeNode<ClampMaxTensor>(lazy_self->GetIrValue(), lazy_max->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }
        
        auto result = torch_xla::bridge::AtenFromXlaTensor(
                torch_xla::XLATensor::Create(std::move(node), *common_device));
        return result;
    };

wonjoolee95 · Aug 11 '22

A couple of clamp unit tests are failing:

ERROR: test_clamp_xla_float32 (__main__.TestDevicePrecisionXLA)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 390, in instantiated_test
    raise rte
  File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 377, in instantiated_test
    result = test(self, **param_kwargs)
  File "/tmp/pytorch/xla/test/../../test/test_torch.py", line 5653, in test_clamp
    self.assertEqual(expect, actual)
  File "/tmp/pytorch/xla/test/pytorch_test_base.py", line 628, in assertEqual
    x, y = self.prepare_for_compare(x, y)
  File "/tmp/pytorch/xla/test/pytorch_test_base.py", line 575, in prepare_for_compare
    y = ty.to(device='cpu')
RuntimeError: Error while lowering: [Float[100,50]] aten::clamp
Error: vector::_M_range_check: __n (which is 2) >= this->size() (which is 2)
Frames:

Looking into it.

wonjoolee95 · Aug 12 '22

The error was caused by missing type promotion logic for the two tensor operands in the shape inference and lowering functions. The latest commit fixes this, and all tests now pass locally.
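
For illustration, a rough sketch of the kind of promotion being described (names like GetXlaShape, XlaHelpers::GetPromotedBinaryOpShape, and XlaHelpers::Promote are assumptions about torch_xla's helpers, not taken from the actual commit):

xla::Shape ClampTensorOutputShape(
    const torch::lazy::Value& self,
    const c10::optional<torch::lazy::Value>& min,
    const c10::optional<torch::lazy::Value>& max) {
  // Sketch only: promote self's shape/dtype against whichever optional bound
  // is present, so e.g. float64 min/max tensors widen a float32 input's
  // result instead of being dropped.
  xla::Shape shape = GetXlaShape(self);
  if (min) {
    shape = XlaHelpers::GetPromotedBinaryOpShape(shape, GetXlaShape(*min));
  }
  if (max) {
    shape = XlaHelpers::GetPromotedBinaryOpShape(shape, GetXlaShape(*max));
  }
  return shape;
}

The same idea applies in Lower: promote the two XlaOps (e.g. via XlaHelpers::Promote) before combining them with xla::Max/xla::Min rather than passing the raw operands through.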

wonjoolee95 · Aug 12 '22

@wonjoolee95 can you fix the linter?

JackCaoG · Aug 16 '22