xla: Codegen for bitwise and, or, xor, and not
PyTorch PR
https://github.com/pytorch/pytorch/pull/82617
LazyIr.h
// IR node for the tensor-tensor `aten::bitwise_and` lazy op.
// Operands are {self, other}; the output shape is precomputed by the
// caller and passed in via `shapes`.
class BitwiseAndTensor : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_and);
  }

  // `shapes` holds the already-inferred output shape; the lambda is the
  // deferred fallback shape function.
  BitwiseAndTensor(const torch::lazy::Value& self,
                   const torch::lazy::Value& other,
                   std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                [&]() { return BitwiseAndTensorOutputShape(self, other); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
// IR node for the tensor-scalar `aten::bitwise_and` lazy op. The second
// operand is the scalar already lifted to an IR value by the caller.
class BitwiseAndScalar : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_and);
  }

  // `shapes` holds the already-inferred output shape; the lambda is the
  // deferred fallback shape function.
  BitwiseAndScalar(const torch::lazy::Value& self,
                   const torch::lazy::Value& other,
                   std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                [&]() { return BitwiseAndScalarOutputShape(self, other); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
// IR node for the unary `aten::bitwise_not` lazy op.
class BitwiseNot : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_not);
  }

  // `shapes` holds the already-inferred output shape; the lambda is the
  // deferred fallback shape function.
  BitwiseNot(const torch::lazy::Value& self,
             std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self}, std::move(shapes),
                [&]() { return BitwiseNotOutputShape(self); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op.
  bool CanBeReused(const torch::lazy::Value& self) const { return false; }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
// IR node for the tensor-tensor `aten::bitwise_or` lazy op.
// Operands are {self, other}; the output shape is precomputed by the
// caller and passed in via `shapes`.
class BitwiseOrTensor : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_or);
  }

  // `shapes` holds the already-inferred output shape; the lambda is the
  // deferred fallback shape function.
  BitwiseOrTensor(const torch::lazy::Value& self,
                  const torch::lazy::Value& other,
                  std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                [&]() { return BitwiseOrTensorOutputShape(self, other); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
// IR node for the tensor-scalar `aten::bitwise_or` lazy op. The second
// operand is the scalar already lifted to an IR value by the caller.
class BitwiseOrScalar : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_or);
  }

  // `shapes` holds the already-inferred output shape; the lambda is the
  // deferred fallback shape function.
  BitwiseOrScalar(const torch::lazy::Value& self,
                  const torch::lazy::Value& other,
                  std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                [&]() { return BitwiseOrScalarOutputShape(self, other); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
// IR node for the tensor-tensor `aten::bitwise_xor` lazy op.
// Operands are {self, other}; the output shape is precomputed by the
// caller and passed in via `shapes`.
class BitwiseXorTensor : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_xor);
  }

  // `shapes` holds the already-inferred output shape; the lambda is the
  // deferred fallback shape function.
  BitwiseXorTensor(const torch::lazy::Value& self,
                   const torch::lazy::Value& other,
                   std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                [&]() { return BitwiseXorTensorOutputShape(self, other); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
// IR node for the tensor-scalar `aten::bitwise_xor` lazy op. The second
// operand is the scalar already lifted to an IR value by the caller.
class BitwiseXorScalar : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_xor);
  }

  // `shapes` holds the already-inferred output shape; the lambda is the
  // deferred fallback shape function.
  BitwiseXorScalar(const torch::lazy::Value& self,
                   const torch::lazy::Value& other,
                   std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                [&]() { return BitwiseXorScalarOutputShape(self, other); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
XLANativeFunctions.cpp
// Lazy dispatch for aten::bitwise_and.Tensor: reuses or builds a
// BitwiseAndTensor IR node and wraps it in a fresh XLA tensor.
at::Tensor XLANativeFunctions::bitwise_and(const at::Tensor & self, const at::Tensor & other) {
  XLA_FN_COUNTER("xla::");
  auto device = torch_xla::bridge::GetXlaDevice(self, other);
  TORCH_INTERNAL_ASSERT(device);
  torch_xla::XLATensorPtr xla_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *device);
  torch_xla::XLATensorPtr xla_other =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other, *device);
  // Try the node cache first; only build a new node on a miss.
  torch::lazy::NodePtr ir_node = torch::lazy::ReuseNode<BitwiseAndTensor>(
      xla_self->GetIrValue(), xla_other->GetIrValue());
  if (!ir_node) {
    // Infer the output shape on meta tensors (no data is touched).
    auto self_meta = to_meta(self);
    auto other_meta = to_meta(other);
    auto out_meta = at::meta::bitwise_and(self_meta, other_meta);
    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, other};
      applySymbolicShapesOnLT(
          "aten::bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor",
          inputs, shapes);
    }
    ir_node = torch::lazy::MakeNode<BitwiseAndTensor>(
        xla_self->GetIrValue(), xla_other->GetIrValue(), std::move(shapes));
    CacheNode(ir_node);
  }
  return torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(ir_node), *device));
}
// Lazy dispatch for aten::bitwise_and.Scalar: lifts the scalar to an IR
// value, then reuses or builds a BitwiseAndScalar node.
at::Tensor XLANativeFunctions::bitwise_and(const at::Tensor & self, const at::Scalar & other) {
  XLA_FN_COUNTER("xla::");
  auto device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(device);
  torch_xla::XLATensorPtr xla_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *device);
  auto scalar_ir =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(other);
  // Try the node cache first; only build a new node on a miss.
  torch::lazy::NodePtr ir_node = torch::lazy::ReuseNode<BitwiseAndScalar>(
      xla_self->GetIrValue(), scalar_ir);
  if (!ir_node) {
    auto shapes = torch::lazy::compute_shape_bitwise_and(self, other);
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, other};
      applySymbolicShapesOnLT(
          "aten::bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor",
          inputs, shapes);
    }
    ir_node = torch::lazy::MakeNode<BitwiseAndScalar>(
        xla_self->GetIrValue(), scalar_ir, std::move(shapes));
    CacheNode(ir_node);
  }
  return torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(ir_node), *device));
}
// Lazy dispatch for aten::bitwise_not: reuses or builds a BitwiseNot IR
// node and wraps it in a fresh XLA tensor.
at::Tensor XLANativeFunctions::bitwise_not(const at::Tensor & self) {
  XLA_FN_COUNTER("xla::");
  auto device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(device);
  torch_xla::XLATensorPtr xla_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *device);
  // Try the node cache first; only build a new node on a miss.
  torch::lazy::NodePtr ir_node =
      torch::lazy::ReuseNode<BitwiseNot>(xla_self->GetIrValue());
  if (!ir_node) {
    // Infer the output shape on a meta tensor (no data is touched).
    auto self_meta = to_meta(self);
    auto out_meta = at::meta::bitwise_not(self_meta);
    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self};
      applySymbolicShapesOnLT("aten::bitwise_not(Tensor self) -> Tensor",
                              inputs, shapes);
    }
    ir_node = torch::lazy::MakeNode<BitwiseNot>(xla_self->GetIrValue(),
                                                std::move(shapes));
    CacheNode(ir_node);
  }
  return torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(ir_node), *device));
}
// Lazy dispatch for aten::bitwise_or.Tensor: reuses or builds a
// BitwiseOrTensor IR node and wraps it in a fresh XLA tensor.
at::Tensor XLANativeFunctions::bitwise_or(const at::Tensor & self, const at::Tensor & other) {
  XLA_FN_COUNTER("xla::");
  auto device = torch_xla::bridge::GetXlaDevice(self, other);
  TORCH_INTERNAL_ASSERT(device);
  torch_xla::XLATensorPtr xla_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *device);
  torch_xla::XLATensorPtr xla_other =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other, *device);
  // Try the node cache first; only build a new node on a miss.
  torch::lazy::NodePtr ir_node = torch::lazy::ReuseNode<BitwiseOrTensor>(
      xla_self->GetIrValue(), xla_other->GetIrValue());
  if (!ir_node) {
    // Infer the output shape on meta tensors (no data is touched).
    auto self_meta = to_meta(self);
    auto other_meta = to_meta(other);
    auto out_meta = at::meta::bitwise_or(self_meta, other_meta);
    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, other};
      applySymbolicShapesOnLT(
          "aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor",
          inputs, shapes);
    }
    ir_node = torch::lazy::MakeNode<BitwiseOrTensor>(
        xla_self->GetIrValue(), xla_other->GetIrValue(), std::move(shapes));
    CacheNode(ir_node);
  }
  return torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(ir_node), *device));
}
// Lazy dispatch for aten::bitwise_or.Scalar: lifts the scalar to an IR
// value, then reuses or builds a BitwiseOrScalar node.
at::Tensor XLANativeFunctions::bitwise_or(const at::Tensor & self, const at::Scalar & other) {
  XLA_FN_COUNTER("xla::");
  auto device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(device);
  torch_xla::XLATensorPtr xla_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *device);
  auto scalar_ir =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(other);
  // Try the node cache first; only build a new node on a miss.
  torch::lazy::NodePtr ir_node = torch::lazy::ReuseNode<BitwiseOrScalar>(
      xla_self->GetIrValue(), scalar_ir);
  if (!ir_node) {
    auto shapes = torch::lazy::compute_shape_bitwise_or(self, other);
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, other};
      applySymbolicShapesOnLT(
          "aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor",
          inputs, shapes);
    }
    ir_node = torch::lazy::MakeNode<BitwiseOrScalar>(
        xla_self->GetIrValue(), scalar_ir, std::move(shapes));
    CacheNode(ir_node);
  }
  return torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(ir_node), *device));
}
// Lazy dispatch for aten::bitwise_xor.Tensor: reuses or builds a
// BitwiseXorTensor IR node and wraps it in a fresh XLA tensor.
at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor & self, const at::Tensor & other) {
  XLA_FN_COUNTER("xla::");
  auto device = torch_xla::bridge::GetXlaDevice(self, other);
  TORCH_INTERNAL_ASSERT(device);
  torch_xla::XLATensorPtr xla_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *device);
  torch_xla::XLATensorPtr xla_other =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other, *device);
  // Try the node cache first; only build a new node on a miss.
  torch::lazy::NodePtr ir_node = torch::lazy::ReuseNode<BitwiseXorTensor>(
      xla_self->GetIrValue(), xla_other->GetIrValue());
  if (!ir_node) {
    // Infer the output shape on meta tensors (no data is touched).
    auto self_meta = to_meta(self);
    auto other_meta = to_meta(other);
    auto out_meta = at::meta::bitwise_xor(self_meta, other_meta);
    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, other};
      applySymbolicShapesOnLT(
          "aten::bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor",
          inputs, shapes);
    }
    ir_node = torch::lazy::MakeNode<BitwiseXorTensor>(
        xla_self->GetIrValue(), xla_other->GetIrValue(), std::move(shapes));
    CacheNode(ir_node);
  }
  return torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(ir_node), *device));
}
// Lazy dispatch for aten::bitwise_xor.Scalar: lifts the scalar to an IR
// value, then reuses or builds a BitwiseXorScalar node.
at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor & self, const at::Scalar & other) {
  XLA_FN_COUNTER("xla::");
  auto device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(device);
  torch_xla::XLATensorPtr xla_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *device);
  auto scalar_ir =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(other);
  // Try the node cache first; only build a new node on a miss.
  torch::lazy::NodePtr ir_node = torch::lazy::ReuseNode<BitwiseXorScalar>(
      xla_self->GetIrValue(), scalar_ir);
  if (!ir_node) {
    auto shapes = torch::lazy::compute_shape_bitwise_xor(self, other);
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, other};
      applySymbolicShapesOnLT(
          "aten::bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor",
          inputs, shapes);
    }
    ir_node = torch::lazy::MakeNode<BitwiseXorScalar>(
        xla_self->GetIrValue(), scalar_ir, std::move(shapes));
    CacheNode(ir_node);
  }
  return torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(ir_node), *device));
}
Testing
Built with `BUILD_CPP_TESTS=0 python setup.py install` in a Docker container on a cloudtop instance.
In python shell, ran:
>>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8, device=xm.xla_device()), torch.tensor([1, 0, 3], dtype=torch.int8, device=xm.xla_device()))
tensor([1, 0, 3], device='xla:0', dtype=torch.int8)
>>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8, device=xm.xla_device()), torch.tensor([1, 0, 3], dtype=torch.int8, device=xm.xla_device()))
tensor([-1, -2, 3], device='xla:0', dtype=torch.int8)
>>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8, device=xm.xla_device()), torch.tensor([1, 0, 3], dtype=torch.int8, device=xm.xla_device()))
tensor([-2, -2, 0], device='xla:0', dtype=torch.int8)
>>> torch.bitwise_not(torch.tensor([-1, -2, 3], dtype=torch.int8, device=xm.xla_device()))
tensor([ 0, 1, -4], device='xla:0', dtype=torch.int8)
>>> torch.bitwise_and(torch.tensor([True, True, False], device=xm.xla_device()), torch.tensor([False, True, False], device=xm.xla_device()))
tensor([False, True, False], device='xla:0')
>>> torch.bitwise_or(torch.tensor([True, True, False], device=xm.xla_device()), torch.tensor([False, True, False], device=xm.xla_device()))
tensor([ True, True, False], device='xla:0')
>>> torch.bitwise_xor(torch.tensor([True, True, False], device=xm.xla_device()), torch.tensor([False, True, False], device=xm.xla_device()))
tensor([ True, False, False], device='xla:0')
>>> torch.bitwise_not(torch.tensor([True, True, False], device=xm.xla_device()))
tensor([False, False, True], device='xla:0')