Codegen ne
Fixes https://github.com/pytorch/xla/issues/3880
Fixes https://github.com/pytorch/xla/issues/3881
LazyIr

The generated IR node classes for ne.Scalar and ne.Tensor:
class NeScalar : public XlaNode {
public:
static torch::lazy::OpKind ClassOpKind() {
return torch::lazy::OpKind(at::aten::ne);
}
NeScalar(const torch::lazy::Value& self, const torch::lazy::Value& other,
std::vector<torch::lazy::Shape>&& shapes)
: XlaNode(torch::lazy::OpKind(at::aten::ne), {self, other},
std::move(shapes),
[&]() { return NeScalarOutputShape(self, other); },
/* num_outputs */ 1, torch::lazy::MHash()) {}
std::string ToString() const override {
std::stringstream ss;
ss << XlaNode::ToString();
return ss.str();
}
bool CanBeReused(const torch::lazy::Value& self,
const torch::lazy::Value& other) const {
return false;
}
torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class NeTensor : public XlaNode {
public:
static torch::lazy::OpKind ClassOpKind() {
return torch::lazy::OpKind(at::aten::ne);
}
NeTensor(const torch::lazy::Value& self, const torch::lazy::Value& other,
std::vector<torch::lazy::Shape>&& shapes)
: XlaNode(torch::lazy::OpKind(at::aten::ne), {self, other},
std::move(shapes),
[&]() { return NeTensorOutputShape(self, other); },
/* num_outputs */ 1, torch::lazy::MHash()) {}
std::string ToString() const override {
std::stringstream ss;
ss << XlaNode::ToString();
return ss.str();
}
bool CanBeReused(const torch::lazy::Value& self,
const torch::lazy::Value& other) const {
return false;
}
torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
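
The codegen only declares Lower() and references NeScalarOutputShape / NeTensorOutputShape; those pieces are still written by hand. A minimal sketch of what they could look like, assuming the InferOutputShape and BuildComparisonOp helpers that the existing comparison lowerings use (the file paths and exact signatures below are illustrative, not copied from this PR):

// torch_xla/csrc/ops/ops_xla_shape_fn.cpp (sketch)
xla::Shape NeScalarOutputShape(const torch::lazy::Value& self,
                               const torch::lazy::Value& other) {
  // Run the lowering on dummy operands to infer the broadcasted bool shape.
  auto lower_for_shape_fn =
      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    return BuildComparisonOp(at::aten::ne, operands[0], operands[1]);
  };
  return InferOutputShape({GetXlaShape(self), GetXlaShape(other)},
                          lower_for_shape_fn);
}

xla::Shape NeTensorOutputShape(const torch::lazy::Value& self,
                               const torch::lazy::Value& other) {
  return NeScalarOutputShape(self, other);
}

// torch_xla/csrc/ops/ops_lower_fn.cpp (sketch)
torch_xla::XlaOpVector NeScalar::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
  xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
  return ReturnOp(BuildComparisonOp(at::aten::ne, xla_input, xla_other), loctx);
}

torch_xla::XlaOpVector NeTensor::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
  xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
  return ReturnOp(BuildComparisonOp(at::aten::ne, xla_input, xla_other), loctx);
}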
NativeFunction

The generated XLANativeFunctions entry points for ne.Scalar and ne.Tensor:
at::Tensor XLANativeFunctions::ne(const at::Tensor& self,
const at::Scalar& other) {
XLA_FN_COUNTER("xla::");
auto common_device = torch_xla::bridge::GetXlaDevice(self);
TORCH_INTERNAL_ASSERT(common_device);
torch_xla::XLATensorPtr lazy_self =
torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
*common_device);
auto node_other =
torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
other, *common_device);
torch::lazy::NodePtr node =
torch::lazy::ReuseNode<NeScalar>(lazy_self->GetIrValue(), node_other);
if (!node) {
auto self_meta = to_meta(self);
auto out_meta = at::meta::ne(self_meta, other);
std::vector<torch::lazy::Shape> shapes{
torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
TORCH_INTERNAL_ASSERT(shapes.size() == 1);
if (torch::lazy::symbolicShapeEnabled()) {
std::vector<torch::jit::IValue> inputs = {self, other};
const char* schema_str =
"aten::ne.Scalar(Tensor self, Scalar other) -> Tensor";
applySymbolicShapesOnLT(schema_str, inputs, shapes);
}
node = torch::lazy::MakeNode<NeScalar>(lazy_self->GetIrValue(), node_other,
std::move(shapes));
CacheNode(node);
}
auto result = torch_xla::bridge::AtenFromXlaTensor(
torch_xla::XLATensor::Create(std::move(node), *common_device));
return result;
};

at::Tensor XLANativeFunctions::ne(const at::Tensor& self,
const at::Tensor& other) {
XLA_FN_COUNTER("xla::");
auto common_device = torch_xla::bridge::GetXlaDevice(self, other);
TORCH_INTERNAL_ASSERT(common_device);
torch_xla::XLATensorPtr lazy_self =
torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
*common_device);
torch_xla::XLATensorPtr lazy_other =
torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other,
*common_device);
torch::lazy::NodePtr node = torch::lazy::ReuseNode<NeTensor>(
lazy_self->GetIrValue(), lazy_other->GetIrValue());
if (!node) {
auto self_meta = to_meta(self);
auto other_meta = to_meta(other);
auto out_meta = at::meta::ne(self_meta, other_meta);
std::vector<torch::lazy::Shape> shapes{
torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
TORCH_INTERNAL_ASSERT(shapes.size() == 1);
if (torch::lazy::symbolicShapeEnabled()) {
std::vector<torch::jit::IValue> inputs = {self, other};
const char* schema_str =
"aten::ne.Tensor(Tensor self, Tensor other) -> Tensor";
applySymbolicShapesOnLT(schema_str, inputs, shapes);
}
node = torch::lazy::MakeNode<NeTensor>(
lazy_self->GetIrValue(), lazy_other->GetIrValue(), std::move(shapes));
CacheNode(node);
}
auto result = torch_xla::bridge::AtenFromXlaTensor(
torch_xla::XLATensor::Create(std::move(node), *common_device));
return result;
};
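
For reference, the new path can be exercised through the cpp test harness by checking that the xla::ne counter (from XLA_FN_COUNTER above) fires instead of an aten:: fallback. A rough sketch, assuming the helpers already defined in test/cpp/test_aten_xla_tensor.cpp (ForEachDevice, CopyToDevice, AllEqual, and the counter expectations):

// test/cpp/test_aten_xla_tensor.cpp (sketch)
TEST_F(AtenXlaTensorTest, TestNe) {
  torch::Tensor a = torch::rand({2, 3}, torch::TensorOptions(torch::kFloat));
  torch::Tensor b = a.clone();
  // Flip one element so the comparison is not trivially all-false.
  b[0][0] += 1;
  torch::Tensor c = torch::ne(a, b);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor xla_a = CopyToDevice(a, device);
    torch::Tensor xla_b = CopyToDevice(b, device);
    torch::Tensor xla_c = torch::ne(xla_a, xla_b);
    AllEqual(c, xla_c);
  });
  // The op should go through the codegen lowering, not fall back to ATen.
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::ne", cpp_test::GetIgnoredCounters());
}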