xla
xla copied to clipboard
Codegen `lt` and `le`
Fixes https://github.com/pytorch/xla/issues/3872, https://github.com/pytorch/xla/issues/3873, https://github.com/pytorch/xla/issues/3874, and https://github.com/pytorch/xla/issues/3875.
LazyIr
// Lazy IR node for aten::le where the right-hand operand came from a scalar.
class LeScalar : public XlaNode {
 public:
  // Op kind shared by every node of this class: aten::le.
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::le);
  }

  // Builds a single-output node with operands {self, other}.
  // `shapes` is the precomputed lazy output shape. The lambda computes the
  // node's output shape through LeScalarOutputShape (assumed to be the shape
  // form XlaNode expects -- confirm against the XlaNode constructor).
  LeScalar(const torch::lazy::Value& self, const torch::lazy::Value& other,
           std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                [&]() { return LeScalarOutputShape(self, other); },
                /*num_outputs=*/1, torch::lazy::MHash()) {}

  // The node carries no extra attributes, so the base description is used as is.
  std::string ToString() const override { return XlaNode::ToString(); }

  // Unconditionally false: this node opts out of IR node reuse (presumably
  // consulted by torch::lazy::ReuseNode<LeScalar>).
  bool CanBeReused(const torch::lazy::Value& /*self*/,
                   const torch::lazy::Value& /*other*/) const {
    return false;
  }

  // Declared here; the XLA lowering is implemented elsewhere.
  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
// Lazy IR node for aten::le where both operands are tensors.
class LeTensor : public XlaNode {
 public:
  // Op kind shared by every node of this class: aten::le.
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::le);
  }

  // Builds a single-output node with operands {self, other}.
  // `shapes` is the precomputed lazy output shape. The lambda computes the
  // node's output shape through LeTensorOutputShape (assumed to be the shape
  // form XlaNode expects -- confirm against the XlaNode constructor).
  LeTensor(const torch::lazy::Value& self, const torch::lazy::Value& other,
           std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                [&]() { return LeTensorOutputShape(self, other); },
                /*num_outputs=*/1, torch::lazy::MHash()) {}

  // The node carries no extra attributes, so the base description is used as is.
  std::string ToString() const override { return XlaNode::ToString(); }

  // Unconditionally false: this node opts out of IR node reuse (presumably
  // consulted by torch::lazy::ReuseNode<LeTensor>).
  bool CanBeReused(const torch::lazy::Value& /*self*/,
                   const torch::lazy::Value& /*other*/) const {
    return false;
  }

  // Declared here; the XLA lowering is implemented elsewhere.
  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
// Lazy IR node for aten::lt where the right-hand operand came from a scalar.
class LtScalar : public XlaNode {
 public:
  // Op kind shared by every node of this class: aten::lt.
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::lt);
  }

  // Builds a single-output node with operands {self, other}.
  // `shapes` is the precomputed lazy output shape. The lambda computes the
  // node's output shape through LtScalarOutputShape (assumed to be the shape
  // form XlaNode expects -- confirm against the XlaNode constructor).
  LtScalar(const torch::lazy::Value& self, const torch::lazy::Value& other,
           std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                [&]() { return LtScalarOutputShape(self, other); },
                /*num_outputs=*/1, torch::lazy::MHash()) {}

  // The node carries no extra attributes, so the base description is used as is.
  std::string ToString() const override { return XlaNode::ToString(); }

  // Unconditionally false: this node opts out of IR node reuse (presumably
  // consulted by torch::lazy::ReuseNode<LtScalar>).
  bool CanBeReused(const torch::lazy::Value& /*self*/,
                   const torch::lazy::Value& /*other*/) const {
    return false;
  }

  // Declared here; the XLA lowering is implemented elsewhere.
  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
// Lazy IR node for aten::lt where both operands are tensors.
class LtTensor : public XlaNode {
 public:
  // Op kind shared by every node of this class: aten::lt.
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::lt);
  }

  // Builds a single-output node with operands {self, other}.
  // `shapes` is the precomputed lazy output shape. The lambda computes the
  // node's output shape through LtTensorOutputShape (assumed to be the shape
  // form XlaNode expects -- confirm against the XlaNode constructor).
  LtTensor(const torch::lazy::Value& self, const torch::lazy::Value& other,
           std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                [&]() { return LtTensorOutputShape(self, other); },
                /*num_outputs=*/1, torch::lazy::MHash()) {}

  // The node carries no extra attributes, so the base description is used as is.
  std::string ToString() const override { return XlaNode::ToString(); }

  // Unconditionally false: this node opts out of IR node reuse (presumably
  // consulted by torch::lazy::ReuseNode<LtTensor>).
  bool CanBeReused(const torch::lazy::Value& /*self*/,
                   const torch::lazy::Value& /*other*/) const {
    return false;
  }

  // Declared here; the XLA lowering is implemented elsewhere.
  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
// NOTE: A second, byte-identical definition of `class LtScalar` used to appear
// here. The class is already defined earlier in this file. Defining the same
// class twice in one translation unit is a redefinition error, so only the
// earlier definition is kept.
// NOTE: A second, byte-identical definition of `class LtTensor` used to appear
// here. The class is already defined earlier in this file. Defining the same
// class twice in one translation unit is a redefinition error, so only the
// earlier definition is kept.
// aten::le.Scalar on XLA tensors (elementwise `self <= other`).
// Tries torch::lazy::ReuseNode for a LeScalar node first. On a miss it infers
// the output dtype/sizes from meta tensors, optionally refines them with
// symbolic shapes, builds and caches a new LeScalar node, and finally wraps
// the node in a fresh XLA-backed at::Tensor on `self`'s device.
at::Tensor XLANativeFunctions::le(const at::Tensor& self,
                                  const at::Scalar& other) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr xla_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  // The scalar operand is turned into an IR value on the same device.
  auto other_value =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          other, *common_device);

  torch::lazy::NodePtr node =
      torch::lazy::ReuseNode<LeScalar>(xla_self->GetIrValue(), other_value);
  if (node == nullptr) {
    // Shape inference: run the op on meta tensors to get dtype and sizes.
    auto meta_self = to_meta(self);
    auto meta_out = at::meta::le(meta_self, other);
    std::vector<torch::lazy::Shape> shapes;
    shapes.emplace_back(meta_out.scalar_type(), meta_out.sizes().vec());
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> shape_inputs = {self, other};
      const char* schema_str =
          "aten::le.Scalar(Tensor self, Scalar other) -> Tensor";
      applySymbolicShapesOnLT(schema_str, shape_inputs, shapes);
    }
    node = torch::lazy::MakeNode<LeScalar>(xla_self->GetIrValue(), other_value,
                                           std::move(shapes));
    CacheNode(node);
  }

  return torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
}
// aten::le.Tensor on XLA tensors (elementwise `self <= other`, both tensors).
// Tries torch::lazy::ReuseNode for a LeTensor node first. On a miss it infers
// the output dtype/sizes from meta tensors, optionally refines them with
// symbolic shapes, builds and caches a new LeTensor node, and finally wraps
// the node in a fresh XLA-backed at::Tensor on the common device.
at::Tensor XLANativeFunctions::le(const at::Tensor& self,
                                  const at::Tensor& other) {
  XLA_FN_COUNTER("xla::");
  // GetXlaDevice(self, other) yields an optional device; it must be present.
  auto common_device = torch_xla::bridge::GetXlaDevice(self, other);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr xla_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch_xla::XLATensorPtr xla_other =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other,
                                                              *common_device);

  torch::lazy::NodePtr node = torch::lazy::ReuseNode<LeTensor>(
      xla_self->GetIrValue(), xla_other->GetIrValue());
  if (node == nullptr) {
    // Shape inference: run the op on meta tensors to get dtype and sizes.
    auto meta_self = to_meta(self);
    auto meta_other = to_meta(other);
    auto meta_out = at::meta::le(meta_self, meta_other);
    std::vector<torch::lazy::Shape> shapes;
    shapes.emplace_back(meta_out.scalar_type(), meta_out.sizes().vec());
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> shape_inputs = {self, other};
      const char* schema_str =
          "aten::le.Tensor(Tensor self, Tensor other) -> Tensor";
      applySymbolicShapesOnLT(schema_str, shape_inputs, shapes);
    }
    node = torch::lazy::MakeNode<LeTensor>(
        xla_self->GetIrValue(), xla_other->GetIrValue(), std::move(shapes));
    CacheNode(node);
  }

  return torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
}
// aten::lt.Scalar on XLA tensors (elementwise `self < other`).
// Tries torch::lazy::ReuseNode for an LtScalar node first. On a miss it infers
// the output dtype/sizes from meta tensors, optionally refines them with
// symbolic shapes, builds and caches a new LtScalar node, and finally wraps
// the node in a fresh XLA-backed at::Tensor on `self`'s device.
at::Tensor XLANativeFunctions::lt(const at::Tensor& self,
                                  const at::Scalar& other) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr xla_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  // The scalar operand is turned into an IR value on the same device.
  auto other_value =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          other, *common_device);

  torch::lazy::NodePtr node =
      torch::lazy::ReuseNode<LtScalar>(xla_self->GetIrValue(), other_value);
  if (node == nullptr) {
    // Shape inference: run the op on meta tensors to get dtype and sizes.
    auto meta_self = to_meta(self);
    auto meta_out = at::meta::lt(meta_self, other);
    std::vector<torch::lazy::Shape> shapes;
    shapes.emplace_back(meta_out.scalar_type(), meta_out.sizes().vec());
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> shape_inputs = {self, other};
      const char* schema_str =
          "aten::lt.Scalar(Tensor self, Scalar other) -> Tensor";
      applySymbolicShapesOnLT(schema_str, shape_inputs, shapes);
    }
    node = torch::lazy::MakeNode<LtScalar>(xla_self->GetIrValue(), other_value,
                                           std::move(shapes));
    CacheNode(node);
  }

  return torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
}
// aten::lt.Tensor on XLA tensors (elementwise `self < other`, both tensors).
// Tries torch::lazy::ReuseNode for an LtTensor node first. On a miss it infers
// the output dtype/sizes from meta tensors, optionally refines them with
// symbolic shapes, builds and caches a new LtTensor node, and finally wraps
// the node in a fresh XLA-backed at::Tensor on the common device.
at::Tensor XLANativeFunctions::lt(const at::Tensor& self,
                                  const at::Tensor& other) {
  XLA_FN_COUNTER("xla::");
  // GetXlaDevice(self, other) yields an optional device; it must be present.
  auto common_device = torch_xla::bridge::GetXlaDevice(self, other);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr xla_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch_xla::XLATensorPtr xla_other =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other,
                                                              *common_device);

  torch::lazy::NodePtr node = torch::lazy::ReuseNode<LtTensor>(
      xla_self->GetIrValue(), xla_other->GetIrValue());
  if (node == nullptr) {
    // Shape inference: run the op on meta tensors to get dtype and sizes.
    auto meta_self = to_meta(self);
    auto meta_other = to_meta(other);
    auto meta_out = at::meta::lt(meta_self, meta_other);
    std::vector<torch::lazy::Shape> shapes;
    shapes.emplace_back(meta_out.scalar_type(), meta_out.sizes().vec());
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> shape_inputs = {self, other};
      const char* schema_str =
          "aten::lt.Tensor(Tensor self, Tensor other) -> Tensor";
      applySymbolicShapesOnLT(schema_str, shape_inputs, shapes);
    }
    node = torch::lazy::MakeNode<LtTensor>(
        xla_self->GetIrValue(), xla_other->GetIrValue(), std::move(shapes));
    CacheNode(node);
  }

  return torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
}