
Codegen lt le

JackCaoG opened this issue 2 years ago

Fix https://github.com/pytorch/xla/issues/3872
Fix https://github.com/pytorch/xla/issues/3873
Fix https://github.com/pytorch/xla/issues/3874
Fix https://github.com/pytorch/xla/issues/3875
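
For context, these are the four ATen overloads being moved to full codegen. A minimal sketch of how they are reached from the C++ API (illustrative only, not part of this change; on an XLA device these calls dispatch into the generated code below):

#include <torch/torch.h>

// Illustrative sketch: the four overloads covered by this change.
void ComparisonOverloads() {
  torch::Tensor a = torch::rand({2, 3});
  torch::Tensor b = torch::rand({2, 3});
  torch::Tensor le_scalar = torch::le(a, 0.5);  // aten::le.Scalar
  torch::Tensor le_tensor = torch::le(a, b);    // aten::le.Tensor
  torch::Tensor lt_scalar = torch::lt(a, 0.5);  // aten::lt.Scalar
  torch::Tensor lt_tensor = torch::lt(a, b);    // aten::lt.Tensor
}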

LazyIr

class LeScalar : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::le);
  }

  LeScalar(const torch::lazy::Value& self, const torch::lazy::Value& other,
           std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::le), {self, other},
                std::move(shapes),
                [&]() { return LeScalarOutputShape(self, other); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class LeTensor : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::le);
  }

  LeTensor(const torch::lazy::Value& self, const torch::lazy::Value& other,
           std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::le), {self, other},
                std::move(shapes),
                [&]() { return LeTensorOutputShape(self, other); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class LtScalar : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::lt);
  }

  LtScalar(const torch::lazy::Value& self, const torch::lazy::Value& other,
           std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::lt), {self, other},
                std::move(shapes),
                [&]() { return LtScalarOutputShape(self, other); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class LtTensor : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::lt);
  }

  LtTensor(const torch::lazy::Value& self, const torch::lazy::Value& other,
           std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::lt), {self, other},
                std::move(shapes),
                [&]() { return LtTensorOutputShape(self, other); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};


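The generated classes only declare Lower() and reference the *OutputShape helpers; those two pieces are written by hand. A rough sketch for LeScalar (the other three ops follow the same pattern), assuming the existing BuildComparisonOp, InferOutputShape and GetXlaShape helpers in torch_xla; exact file names and signatures may differ:

// ops_xla_shape_fn.cpp: hand-written shape inference (sketch).
xla::Shape LeScalarOutputShape(const torch::lazy::Value& self,
                               const torch::lazy::Value& other) {
  auto lower_for_shape_fn =
      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    return BuildComparisonOp(at::aten::le, operands[0], operands[1]);
  };
  return InferOutputShape({GetXlaShape(self), GetXlaShape(other)},
                          lower_for_shape_fn);
}

// ops_lower_fn.cpp: hand-written lowering invoked by the generated node (sketch).
torch_xla::XlaOpVector LeScalar::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
  xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
  return ReturnOp(BuildComparisonOp(at::aten::le, xla_input, xla_other),
                  loctx);
}
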
XLANativeFunctions

at::Tensor XLANativeFunctions::le(const at::Tensor& self,
                                  const at::Scalar& other) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  auto node_other =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          other, *common_device);
  torch::lazy::NodePtr node =
      torch::lazy::ReuseNode<LeScalar>(lazy_self->GetIrValue(), node_other);
  if (!node) {
    auto self_meta = to_meta(self);
    auto out_meta = at::meta::le(self_meta, other);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, other};
      const char* schema_str =
          "aten::le.Scalar(Tensor self, Scalar other) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<LeScalar>(lazy_self->GetIrValue(), node_other,
                                           std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};

at::Tensor XLANativeFunctions::le(const at::Tensor& self,
                                  const at::Tensor& other) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self, other);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_other =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other,
                                                              *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<LeTensor>(
      lazy_self->GetIrValue(), lazy_other->GetIrValue());
  if (!node) {
    auto self_meta = to_meta(self);
    auto other_meta = to_meta(other);
    auto out_meta = at::meta::le(self_meta, other_meta);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, other};
      const char* schema_str =
          "aten::le.Tensor(Tensor self, Tensor other) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<LeTensor>(
        lazy_self->GetIrValue(), lazy_other->GetIrValue(), std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};
at::Tensor XLANativeFunctions::lt(const at::Tensor& self,
                                  const at::Scalar& other) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  auto node_other =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          other, *common_device);
  torch::lazy::NodePtr node =
      torch::lazy::ReuseNode<LtScalar>(lazy_self->GetIrValue(), node_other);
  if (!node) {
    auto self_meta = to_meta(self);
    auto out_meta = at::meta::lt(self_meta, other);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, other};
      const char* schema_str =
          "aten::lt.Scalar(Tensor self, Scalar other) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<LtScalar>(lazy_self->GetIrValue(), node_other,
                                           std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};

at::Tensor XLANativeFunctions::lt(const at::Tensor& self,
                                  const at::Tensor& other) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self, other);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_other =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other,
                                                              *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<LtTensor>(
      lazy_self->GetIrValue(), lazy_other->GetIrValue());
  if (!node) {
    auto self_meta = to_meta(self);
    auto other_meta = to_meta(other);
    auto out_meta = at::meta::lt(self_meta, other_meta);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, other};
      const char* schema_str =
          "aten::lt.Tensor(Tensor self, Tensor other) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<LtTensor>(
        lazy_self->GetIrValue(), lazy_other->GetIrValue(), std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};
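
Sanity check (not generated): a counter-based test in the style of test/cpp/test_aten_xla_tensor.cpp confirms that lt/le stay on the codegen'd path instead of falling back to aten::. Helper names below follow that file; treat this as a sketch:

TEST_F(AtenXlaTensorTest, TestLtScalar) {
  torch::Tensor input =
      torch::rand({2, 3}, torch::TensorOptions(torch::kFloat));
  torch::Scalar other(0.5);
  torch::Tensor result = torch::lt(input, other);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor xla_input = CopyToDevice(input, device);
    torch::Tensor xla_result = torch::lt(xla_input, other);
    AllEqual(result, xla_result);
  });
  // No aten:: fallback counter should move, and the xla::lt counter should.
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::lt", cpp_test::GetIgnoredCounters());
}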

JackCaoG, Aug 13 '22 02:08