Repository: pytorch/xla

Codegen take

Status: Open · opened by wonjoolee95 on Aug 2, 2022 · 0 comments

Codegen for `take` (`aten::take`) in the XLA lazy-tensor backend.

Companion shape-inference PR in PyTorch: https://github.com/pytorch/pytorch/pull/82679


Generated code in `LazyIr.h`:

// Lazy IR node for `aten::take`, whose two operands are {self, index}.
//
// The node has no attributes beyond its operands. It therefore seeds its hash
// with an empty torch::lazy::MHash(), and ToString() adds nothing to the base
// XlaNode text.
class Take : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::take);
  }

  // `shapes` are the precomputed torch::lazy shapes (the caller obtains them
  // from torch::lazy::compute_shape_take). The XLA-side output shape is
  // derived separately by TakeOutputShape(self, index).
  Take(const torch::lazy::Value& self, const torch::lazy::Value& index,
       std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, index}, std::move(shapes),
                // Captures by reference. This presumes XlaNode evaluates the
                // shape function during construction; confirm before storing it.
                [&]() { return TakeOutputShape(self, index); },
                /*num_outputs=*/1, torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Unconditionally reports "not reusable". torch::lazy::ReuseNode<Take>
  // presumably returns null as a result, so a fresh node is always built.
  // NOTE(review): this looks deliberate for this backend; confirm before
  // enabling operand comparison.
  bool CanBeReused(const torch::lazy::Value& /*self*/,
                   const torch::lazy::Value& /*index*/) const {
    return false;
  }

  // Lowers the node to XLA ops; the definition is out of line.
  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

Generated code in `XLANativeFunctions.cpp`:

     // XLA entry point for aten::take(Tensor self, Tensor index) -> Tensor.
     //
     // Steps:
     //   1. Resolve the XLA device shared by both inputs.
     //   2. Wrap each input as an XLA tensor on that device.
     //   3. Ask ReuseNode for an existing `Take` node. Otherwise build one from
     //      the shapes computed by torch::lazy::compute_shape_take (refined
     //      symbolically when symbolic shapes are enabled) and cache it.
     //   4. Return a new XLA-backed at::Tensor holding that node.
     at::Tensor XLANativeFunctions::take(const at::Tensor & self, const at::Tensor & index) {
        XLA_FN_COUNTER("xla::");

        auto device = torch_xla::bridge::GetXlaDevice(self, index);
        TORCH_INTERNAL_ASSERT(device);

        // Per the helper's name, wrapped-number inputs presumably get a fresh
        // XLA tensor on `device`. Confirm against the bridge implementation.
        torch_xla::XLATensorPtr xla_self =
            torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *device);
        torch_xla::XLATensorPtr xla_index =
            torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(index, *device);

        torch::lazy::NodePtr take_node =
            torch::lazy::ReuseNode<Take>(xla_self->GetIrValue(), xla_index->GetIrValue());
        if (!take_node) {
            auto take_shapes = torch::lazy::compute_shape_take(self, index);
            // aten::take has a single output, so exactly one shape is expected.
            TORCH_INTERNAL_ASSERT(take_shapes.size() == 1);

            if (torch::lazy::symbolicShapeEnabled()) {
                const char* const kSchema = "aten::take(Tensor self, Tensor index) -> Tensor";
                std::vector<torch::jit::IValue> schema_inputs{self, index};
                applySymbolicShapesOnLT(kSchema, schema_inputs, take_shapes);
            }

            take_node = torch::lazy::MakeNode<Take>(
                xla_self->GetIrValue(), xla_index->GetIrValue(), std::move(take_shapes));
            CacheNode(take_node);
        }

        return torch_xla::bridge::AtenFromXlaTensor(
            torch_xla::XLATensor::Create(std::move(take_node), *device));
     }

Posted by wonjoolee95 on Aug 02 '22 at 22:08.