
Codegen addcdiv and addcmul

Open · JackCaoG opened this issue 3 years ago · 4 comments

Fix https://github.com/pytorch/xla/issues/3765, Fix https://github.com/pytorch/xla/issues/3766, Fix https://github.com/pytorch/xla/issues/3767

Example PR of a codegen'd op that takes an at::Scalar.

The current codegen upload does not take a scalar type, so we always upload with the default type (f64). To fix this issue, we need to identify the Value that was an at::Scalar in the original aten_xla_type.cpp and cast it to the correct dtype.
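A minimal sketch of that fix, assuming a hypothetical CastToScalarType helper (the actual cast utility in torch_xla may differ):

// Hypothetical sketch, not the actual generated code: cast the uploaded
// scalar to the dtype of the tensor it will be combined with, so that an
// f32 tensor is not mixed with an f64 scalar constant.
torch::lazy::Value GetScalarIrValueWithDtype(
    const at::Scalar& value, const at::Tensor& self,
    const torch::lazy::BackendDevice& device) {
  torch::lazy::Value raw =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          value, device);
  // CastToScalarType is an assumption; the point is to rewrite the default
  // f64 upload to self.scalar_type() before the node is built.
  return CastToScalarType(raw, self.scalar_type());
}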

LazyIR

class Addcdiv : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::addcdiv);
  }

  Addcdiv(const torch::lazy::Value& self, const torch::lazy::Value& tensor1,
          const torch::lazy::Value& tensor2, const torch::lazy::Value& value,
          std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(
            torch::lazy::OpKind(at::aten::addcdiv),
            {self, tensor1, tensor2, value}, std::move(shapes),
            [&]() { return AddcdivOutputShape(self, tensor1, tensor2, value); },
            /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& tensor1,
                   const torch::lazy::Value& tensor2,
                   const torch::lazy::Value& value) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class Addcmul : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::addcmul);
  }

  Addcmul(const torch::lazy::Value& self, const torch::lazy::Value& tensor1,
          const torch::lazy::Value& tensor2, const torch::lazy::Value& value,
          std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(
            torch::lazy::OpKind(at::aten::addcmul),
            {self, tensor1, tensor2, value}, std::move(shapes),
            [&]() { return AddcmulOutputShape(self, tensor1, tensor2, value); },
            /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& tensor1,
                   const torch::lazy::Value& tensor2,
                   const torch::lazy::Value& value) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
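Note that in these generated nodes, value arrives as a torch::lazy::Value operand rather than an at::Scalar: the wrapper below has already uploaded the scalar via GetIrValueForScalarFromCodegen, so its dtype is fixed before the node is built. That is exactly why the upload dtype matters.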

XLANativeFunction

at::Tensor XLANativeFunctions::addcdiv(const at::Tensor& self,
                                       const at::Tensor& tensor1,
                                       const at::Tensor& tensor2,
                                       const at::Scalar& value) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self, tensor1, tensor2);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_tensor1 =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(tensor1,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_tensor2 =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(tensor2,
                                                              *common_device);
  // NOTE: this upload is where the dtype problem lives; the codegen does not
  // pass value's scalar type, so the scalar is uploaded as the default f64.
  auto node_value =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          value, *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<Addcdiv>(
      lazy_self->GetIrValue(), lazy_tensor1->GetIrValue(),
      lazy_tensor2->GetIrValue(), node_value);
  if (!node) {
    auto self_meta = to_meta(self);
    auto tensor1_meta = to_meta(tensor1);
    auto tensor2_meta = to_meta(tensor2);
    auto out_meta =
        at::meta::addcdiv(self_meta, tensor1_meta, tensor2_meta, value);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, tensor1, tensor2, value};
      const char* schema_str =
          "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, "
          "Scalar value=1) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<Addcdiv>(
        lazy_self->GetIrValue(), lazy_tensor1->GetIrValue(),
        lazy_tensor2->GetIrValue(), node_value, std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};

at::Tensor XLANativeFunctions::addcmul(const at::Tensor& self,
                                       const at::Tensor& tensor1,
                                       const at::Tensor& tensor2,
                                       const at::Scalar& value) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self, tensor1, tensor2);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_tensor1 =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(tensor1,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_tensor2 =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(tensor2,
                                                              *common_device);
  auto node_value =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          value, *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<Addcmul>(
      lazy_self->GetIrValue(), lazy_tensor1->GetIrValue(),
      lazy_tensor2->GetIrValue(), node_value);
  if (!node) {
    auto self_meta = to_meta(self);
    auto tensor1_meta = to_meta(tensor1);
    auto tensor2_meta = to_meta(tensor2);
    auto out_meta =
        at::meta::addcmul(self_meta, tensor1_meta, tensor2_meta, value);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, tensor1, tensor2, value};
      const char* schema_str =
          "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, "
          "Scalar value=1) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<Addcmul>(
        lazy_self->GetIrValue(), lazy_tensor1->GetIrValue(),
        lazy_tensor2->GetIrValue(), node_value, std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};
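For reference, a hedged sketch of what the corresponding Lower for Addcdiv could look like (the real lowering lives in torch_xla's generated lowering file; the exact helpers and broadcast handling may differ):

// Sketch only: fetch the four operands wired up in the Addcdiv constructor
// and emit the XLA computation.
torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_self = loctx->GetOutputOp(operand(0));
  xla::XlaOp xla_tensor1 = loctx->GetOutputOp(operand(1));
  xla::XlaOp xla_tensor2 = loctx->GetOutputOp(operand(2));
  xla::XlaOp xla_value = loctx->GetOutputOp(operand(3));
  // addcdiv(self, t1, t2, value) = self + value * (t1 / t2). If xla_value
  // was uploaded as f64 while the tensors are f32, an add like this is where
  // the mixed-precision error below comes from.
  return ReturnOp(xla_self + xla_value * (xla_tensor1 / xla_tensor2), loctx);
}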

JackCaoG · Jul 26 '22 04:07

https://github.com/pytorch/pytorch/pull/82970#pullrequestreview-1065922141 fixed the device issue.

JackCaoG · Aug 09 '22 00:08

The issue comes from trying to add f64 and f32:

2 root error(s) found.
  (0) INTERNAL: during context [Unknown]: Seen floating point types of different precisions in %add.122 = f64[1,16]{1,0} add(f64[1,16]{1,0} %dot.119, f32[1,16]{1,0} %broadcast.121), metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="[email protected]" source_line=114}, but mixed precision is disallowed.
	 [[{{node XRTCompile}}]]
	 [[XRTCompile_G3]]

JackCaoG · Aug 10 '22 00:08

The issue was that when we upload an at::Scalar to the device, the current codegen upload does not take a scalar type, so we always upload with the default type (f64). To fix this, we need to identify the Value that was an at::Scalar in the original aten_xla_type.cpp and cast it to the correct dtype. FYI @wonjoolee95

JackCaoG · Aug 10 '22 00:08

@wonjoolee95 I think this one is ready for review.

JackCaoG · Aug 12 '22 23:08