
Codegen `flip` op

Open wonjoolee95 opened this issue 2 years ago • 4 comments

Fixes https://github.com/pytorch/xla/issues/3924



LazyIr.h:

```cpp
class Flip : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::flip);
  }

  Flip(const torch::lazy::Value& self, const ::std::vector<int64_t>& dims,
       std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::flip), {self}, std::move(shapes),
                [&]() { return FlipOutputShape(self, dims); },
                /* num_outputs */ 1, torch::lazy::MHash(dims)),
        dims(dims) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    ss << ", dims=" << dims;
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const ::std::vector<int64_t>& dims) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;

  ::std::vector<int64_t> dims;
};
```

XLANativeFunctions.cpp:

```cpp
at::Tensor XLANativeFunctions::flip(const at::Tensor& self,
                                    at::IntArrayRef dims) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<Flip>(
      lazy_self->GetIrValue(), std::vector<int64_t>(dims.begin(), dims.end()));
  if (!node) {
    auto shapes = torch::lazy::compute_shape_flip(self, dims);
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, dims};
      const char* schema_str = "aten::flip(Tensor self, int[] dims) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<Flip>(
        lazy_self->GetIrValue(),
        std::vector<int64_t>(dims.begin(), dims.end()), std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
}
```
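The `ReuseNode` / `MakeNode` / `CacheNode` sequence above is the lazy-tensor node-caching pattern: try to reuse an IR node keyed by its op and arguments, and only construct and cache a fresh one on a miss. A minimal sketch of that pattern, with hypothetical names and a plain hash map in place of torch::lazy's actual cache:

```cpp
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the ReuseNode/MakeNode/CacheNode pattern:
// nodes are keyed by a hash of their op and arguments, so retracing the
// same computation returns the same IR node instead of allocating a new one.
struct Node {
  std::string op;
  int64_t arg;
};

class NodeCache {
 public:
  // Returns a previously cached node for (op, arg), or nullptr on a miss.
  std::shared_ptr<Node> Reuse(const std::string& op, int64_t arg) {
    auto it = cache_.find(Key(op, arg));
    return it != cache_.end() ? it->second : nullptr;
  }

  // Builds a new node and records it under its key.
  std::shared_ptr<Node> MakeAndCache(const std::string& op, int64_t arg) {
    auto node = std::make_shared<Node>(Node{op, arg});
    cache_[Key(op, arg)] = node;
    return node;
  }

 private:
  static size_t Key(const std::string& op, int64_t arg) {
    return std::hash<std::string>()(op) ^ (std::hash<int64_t>()(arg) << 1);
  }
  std::unordered_map<size_t, std::shared_ptr<Node>> cache_;
};
```

With this shape, the wrapper's control flow reduces to: try `Reuse`, and if that yields `nullptr`, compute shapes and call `MakeAndCache`.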

wonjoolee95 avatar Aug 24 '22 22:08 wonjoolee95

ehh ok

```
======================================================================
FAIL: test_flip_check_throws (__main__.TestAtenXlaTensor)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/pytorch/xla/test/test_operations.py", line 1361, in test_flip_check_throws
    self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1))
AssertionError: RuntimeError not raised by <lambda>
```

I guess we need to fix this one first.
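The failing test calls `data.flip(0, 1, 1)` and expects a `RuntimeError` because dim 1 is repeated, so the XLA path needs the same argument validation eager ATen performs. A self-contained sketch of that check (the helper name is hypothetical; the real fix would live in the lowering or the wrapper above):

```cpp
#include <cstdint>
#include <set>
#include <stdexcept>
#include <vector>

// Hypothetical validation helper: aten::flip requires every dim to be in
// range and to appear at most once (after normalizing negative dims), and
// eager ATen raises a RuntimeError otherwise -- which is exactly what the
// failing test checks with data.flip(0, 1, 1).
std::vector<int64_t> CheckFlipDims(std::vector<int64_t> dims, int64_t rank) {
  std::set<int64_t> seen;
  for (auto& d : dims) {
    if (d < -rank || d >= rank) {
      throw std::runtime_error("flip: dim out of range");
    }
    if (d < 0) d += rank;  // normalize negative dims
    if (!seen.insert(d).second) {
      throw std::runtime_error("flip: dim appears more than once");
    }
  }
  return dims;
}
```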

JackCaoG avatar Aug 25 '22 01:08 JackCaoG

Quick question: I notice that in the class `Flip` there is a private member `::std::vector<int64_t> dims;`. For me, most of the arguments I pass to an `XlaNode` constructor are converted to type `const torch::lazy::Value`. I'd like to keep the original type. Just wondering how you created the private member.

vanbasten23 avatar Sep 16 '22 18:09 vanbasten23

> Quick question: I notice that in the class `Flip` there is a private member `::std::vector<int64_t> dims;`. For me, most of the arguments I pass to an `XlaNode` constructor are converted to type `const torch::lazy::Value`. I'd like to keep the original type. Just wondering how you created the private member.

For this PR, that was handled by the codegen automatically. The `std::vector<int64_t> dims` is a parameter to the `flip` op, so it's codegen'ed to be a private member. Is there a specific op you're working on where you're seeing this issue?

wonjoolee95 avatar Sep 16 '22 19:09 wonjoolee95

> Quick question: I notice that in the class `Flip` there is a private member `::std::vector<int64_t> dims;`. For me, most of the arguments I pass to an `XlaNode` constructor are converted to type `const torch::lazy::Value`. I'd like to keep the original type. Just wondering how you created the private member.

> For this PR, that was handled by the codegen automatically. The `std::vector<int64_t> dims` is a parameter to the `flip` op, so it's codegen'ed to be a private member. Is there a specific op you're working on where you're seeing this issue?

Yeah, for example, I was trying to codegen `clamp`, which takes a `const c10::optional<at::Scalar>& min`. I'd like to make this `min` a private member of `class Clamp : XlaNode`. It seems I would have to change the codegen code to make this happen, right?

vanbasten23 avatar Sep 20 '22 00:09 vanbasten23
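For reference, and as far as I can tell from the `Flip` output above, the lazy codegen turns tensor-like arguments into `torch::lazy::Value` operands but keeps scalar-like arguments (int lists, scalars, optionals of scalars) by value as members, folding them into the node hash. A hypothetical, self-contained sketch of what that could look like for `clamp`, with `std::optional<double>` standing in for `c10::optional<at::Scalar>`:

```cpp
#include <optional>
#include <sstream>
#include <string>

// Hypothetical sketch of a codegen'ed Clamp node: optional scalar
// arguments are copied into by-value members rather than becoming
// Value operands (std::optional<double> stands in here for
// c10::optional<at::Scalar>).
class ClampNode {
 public:
  ClampNode(std::optional<double> min, std::optional<double> max)
      : min_(min), max_(max) {}

  // Mirrors Flip::ToString: append each member to the node description,
  // printing "null" for an absent optional.
  std::string ToString() const {
    std::ostringstream ss;
    ss << "clamp, min=" << (min_ ? std::to_string(*min_) : "null")
       << ", max=" << (max_ ? std::to_string(*max_) : "null");
    return ss.str();
  }

  std::optional<double> min_;
  std::optional<double> max_;
};
```

If the generator already treats `c10::optional<at::Scalar>` as a scalar-like type, `min`/`max` should come out as members without touching the codegen itself, but that is worth verifying against the generated `LazyIr.h` for `clamp`.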