xla — issue content (code below was copied to the clipboard from the rendered page)

Codegen for bitwise and, or, xor, and not

Status: Open — steventk-g opened this issue 3 years ago • 0 comments

PyTorch PR

https://github.com/pytorch/pytorch/pull/82617

LazyIr.h

// Lazy IR node for aten::bitwise_and.Tensor (element-wise AND of two tensors).
class BitwiseAndTensor : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_and);
  }

  // `shapes` is the precomputed output shape vector (exactly one output).
  BitwiseAndTensor(const torch::lazy::Value& self,
                   const torch::lazy::Value& other,
                   std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                // Fallback shape fn, used when `shapes` is not supplied upstream.
                [&]() { return BitwiseAndTensorOutputShape(self, other); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op: the executor always builds a new node.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

// Lazy IR node for aten::bitwise_and.Scalar; `other` is the scalar operand
// already promoted to an IR value by the caller.
class BitwiseAndScalar : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_and);
  }

  // `shapes` is the precomputed output shape vector (exactly one output).
  BitwiseAndScalar(const torch::lazy::Value& self,
                   const torch::lazy::Value& other,
                   std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                // Fallback shape fn, used when `shapes` is not supplied upstream.
                [&]() { return BitwiseAndScalarOutputShape(self, other); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op: the executor always builds a new node.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

// Lazy IR node for aten::bitwise_not (unary element-wise NOT).
class BitwiseNot : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_not);
  }

  // `shapes` is the precomputed output shape vector (exactly one output).
  BitwiseNot(const torch::lazy::Value& self,
             std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self}, std::move(shapes),
                // Fallback shape fn, used when `shapes` is not supplied upstream.
                [&]() { return BitwiseNotOutputShape(self); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op: the executor always builds a new node.
  bool CanBeReused(const torch::lazy::Value& self) const { return false; }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

// Lazy IR node for aten::bitwise_or.Tensor (element-wise OR of two tensors).
class BitwiseOrTensor : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_or);
  }

  // `shapes` is the precomputed output shape vector (exactly one output).
  BitwiseOrTensor(const torch::lazy::Value& self,
                  const torch::lazy::Value& other,
                  std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                // Fallback shape fn, used when `shapes` is not supplied upstream.
                [&]() { return BitwiseOrTensorOutputShape(self, other); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op: the executor always builds a new node.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

// Lazy IR node for aten::bitwise_or.Scalar; `other` is the scalar operand
// already promoted to an IR value by the caller.
class BitwiseOrScalar : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_or);
  }

  // `shapes` is the precomputed output shape vector (exactly one output).
  BitwiseOrScalar(const torch::lazy::Value& self,
                  const torch::lazy::Value& other,
                  std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                // Fallback shape fn, used when `shapes` is not supplied upstream.
                [&]() { return BitwiseOrScalarOutputShape(self, other); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op: the executor always builds a new node.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

// Lazy IR node for aten::bitwise_xor.Tensor (element-wise XOR of two tensors).
class BitwiseXorTensor : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_xor);
  }

  // `shapes` is the precomputed output shape vector (exactly one output).
  BitwiseXorTensor(const torch::lazy::Value& self,
                   const torch::lazy::Value& other,
                   std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                // Fallback shape fn, used when `shapes` is not supplied upstream.
                [&]() { return BitwiseXorTensorOutputShape(self, other); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op: the executor always builds a new node.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

// Lazy IR node for aten::bitwise_xor.Scalar; `other` is the scalar operand
// already promoted to an IR value by the caller.
class BitwiseXorScalar : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::bitwise_xor);
  }

  // `shapes` is the precomputed output shape vector (exactly one output).
  BitwiseXorScalar(const torch::lazy::Value& self,
                   const torch::lazy::Value& other,
                   std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(ClassOpKind(), {self, other}, std::move(shapes),
                // Fallback shape fn, used when `shapes` is not supplied upstream.
                [&]() { return BitwiseXorScalarOutputShape(self, other); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override { return XlaNode::ToString(); }

  // Node reuse is disabled for this op: the executor always builds a new node.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

XLANativeFunctions.cpp

    // aten::bitwise_and.Tensor — reuse or build a BitwiseAndTensor IR node and
    // wrap it in a fresh XLA tensor on the inputs' common device.
    at::Tensor XLANativeFunctions::bitwise_and(const at::Tensor & self, const at::Tensor & other) {
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self, other);
        TORCH_INTERNAL_ASSERT(common_device);

        torch_xla::XLATensorPtr lazy_self =
            torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch_xla::XLATensorPtr lazy_other =
            torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<BitwiseAndTensor>(
            lazy_self->GetIrValue(), lazy_other->GetIrValue());
        if (!node) {
            // Infer the output shape on meta tensors: no real computation runs.
            auto self_meta = to_meta(self);
            auto other_meta = to_meta(other);
            auto out_meta = at::meta::bitwise_and(self_meta, other_meta);
            std::vector<torch::lazy::Shape> shapes{
                torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if (torch::lazy::symbolicShapeEnabled()) {
                std::vector<torch::jit::IValue> inputs = { self, other };
                const char* schema_str = "aten::bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
            node = torch::lazy::MakeNode<BitwiseAndTensor>(
                lazy_self->GetIrValue(), lazy_other->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }

        return torch_xla::bridge::AtenFromXlaTensor(
            torch_xla::XLATensor::Create(std::move(node), *common_device));
    }

    
    // aten::bitwise_and.Scalar — promote the scalar to an IR value, then reuse
    // or build a BitwiseAndScalar node on self's device.
    at::Tensor XLANativeFunctions::bitwise_and(const at::Tensor & self, const at::Scalar & other) {
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self);
        TORCH_INTERNAL_ASSERT(common_device);

        torch_xla::XLATensorPtr lazy_self =
            torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        auto node_other =
            torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(other);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<BitwiseAndScalar>(
            lazy_self->GetIrValue(), node_other);
        if (!node) {
            auto shapes = torch::lazy::compute_shape_bitwise_and(self, other);
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if (torch::lazy::symbolicShapeEnabled()) {
                std::vector<torch::jit::IValue> inputs = { self, other };
                const char* schema_str = "aten::bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
            node = torch::lazy::MakeNode<BitwiseAndScalar>(
                lazy_self->GetIrValue(), node_other, std::move(shapes));
            CacheNode(node);
        }

        return torch_xla::bridge::AtenFromXlaTensor(
            torch_xla::XLATensor::Create(std::move(node), *common_device));
    }

    
    // aten::bitwise_not — reuse or build a BitwiseNot IR node and wrap it in a
    // fresh XLA tensor on self's device.
    at::Tensor XLANativeFunctions::bitwise_not(const at::Tensor & self) {
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self);
        TORCH_INTERNAL_ASSERT(common_device);

        torch_xla::XLATensorPtr lazy_self =
            torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<BitwiseNot>(lazy_self->GetIrValue());
        if (!node) {
            // Infer the output shape on a meta tensor: no real computation runs.
            auto self_meta = to_meta(self);
            auto out_meta = at::meta::bitwise_not(self_meta);
            std::vector<torch::lazy::Shape> shapes{
                torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if (torch::lazy::symbolicShapeEnabled()) {
                std::vector<torch::jit::IValue> inputs = { self };
                const char* schema_str = "aten::bitwise_not(Tensor self) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
            node = torch::lazy::MakeNode<BitwiseNot>(lazy_self->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }

        return torch_xla::bridge::AtenFromXlaTensor(
            torch_xla::XLATensor::Create(std::move(node), *common_device));
    }

    
    // aten::bitwise_or.Tensor — reuse or build a BitwiseOrTensor IR node and
    // wrap it in a fresh XLA tensor on the inputs' common device.
    at::Tensor XLANativeFunctions::bitwise_or(const at::Tensor & self, const at::Tensor & other) {
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self, other);
        TORCH_INTERNAL_ASSERT(common_device);

        torch_xla::XLATensorPtr lazy_self =
            torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch_xla::XLATensorPtr lazy_other =
            torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<BitwiseOrTensor>(
            lazy_self->GetIrValue(), lazy_other->GetIrValue());
        if (!node) {
            // Infer the output shape on meta tensors: no real computation runs.
            auto self_meta = to_meta(self);
            auto other_meta = to_meta(other);
            auto out_meta = at::meta::bitwise_or(self_meta, other_meta);
            std::vector<torch::lazy::Shape> shapes{
                torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if (torch::lazy::symbolicShapeEnabled()) {
                std::vector<torch::jit::IValue> inputs = { self, other };
                const char* schema_str = "aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
            node = torch::lazy::MakeNode<BitwiseOrTensor>(
                lazy_self->GetIrValue(), lazy_other->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }

        return torch_xla::bridge::AtenFromXlaTensor(
            torch_xla::XLATensor::Create(std::move(node), *common_device));
    }

    
    // aten::bitwise_or.Scalar — promote the scalar to an IR value, then reuse
    // or build a BitwiseOrScalar node on self's device.
    at::Tensor XLANativeFunctions::bitwise_or(const at::Tensor & self, const at::Scalar & other) {
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self);
        TORCH_INTERNAL_ASSERT(common_device);

        torch_xla::XLATensorPtr lazy_self =
            torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        auto node_other =
            torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(other);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<BitwiseOrScalar>(
            lazy_self->GetIrValue(), node_other);
        if (!node) {
            auto shapes = torch::lazy::compute_shape_bitwise_or(self, other);
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if (torch::lazy::symbolicShapeEnabled()) {
                std::vector<torch::jit::IValue> inputs = { self, other };
                const char* schema_str = "aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
            node = torch::lazy::MakeNode<BitwiseOrScalar>(
                lazy_self->GetIrValue(), node_other, std::move(shapes));
            CacheNode(node);
        }

        return torch_xla::bridge::AtenFromXlaTensor(
            torch_xla::XLATensor::Create(std::move(node), *common_device));
    }

    
    // aten::bitwise_xor.Tensor — reuse or build a BitwiseXorTensor IR node and
    // wrap it in a fresh XLA tensor on the inputs' common device.
    at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor & self, const at::Tensor & other) {
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self, other);
        TORCH_INTERNAL_ASSERT(common_device);

        torch_xla::XLATensorPtr lazy_self =
            torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch_xla::XLATensorPtr lazy_other =
            torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<BitwiseXorTensor>(
            lazy_self->GetIrValue(), lazy_other->GetIrValue());
        if (!node) {
            // Infer the output shape on meta tensors: no real computation runs.
            auto self_meta = to_meta(self);
            auto other_meta = to_meta(other);
            auto out_meta = at::meta::bitwise_xor(self_meta, other_meta);
            std::vector<torch::lazy::Shape> shapes{
                torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if (torch::lazy::symbolicShapeEnabled()) {
                std::vector<torch::jit::IValue> inputs = { self, other };
                const char* schema_str = "aten::bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
            node = torch::lazy::MakeNode<BitwiseXorTensor>(
                lazy_self->GetIrValue(), lazy_other->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }

        return torch_xla::bridge::AtenFromXlaTensor(
            torch_xla::XLATensor::Create(std::move(node), *common_device));
    }

    
    // aten::bitwise_xor.Scalar — promote the scalar to an IR value, then reuse
    // or build a BitwiseXorScalar node on self's device.
    at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor & self, const at::Scalar & other) {
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self);
        TORCH_INTERNAL_ASSERT(common_device);

        torch_xla::XLATensorPtr lazy_self =
            torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        auto node_other =
            torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(other);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<BitwiseXorScalar>(
            lazy_self->GetIrValue(), node_other);
        if (!node) {
            auto shapes = torch::lazy::compute_shape_bitwise_xor(self, other);
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if (torch::lazy::symbolicShapeEnabled()) {
                std::vector<torch::jit::IValue> inputs = { self, other };
                const char* schema_str = "aten::bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
            node = torch::lazy::MakeNode<BitwiseXorScalar>(
                lazy_self->GetIrValue(), node_other, std::move(shapes));
            CacheNode(node);
        }

        return torch_xla::bridge::AtenFromXlaTensor(
            torch_xla::XLATensor::Create(std::move(node), *common_device));
    }

Testing

Built with `BUILD_CPP_TESTS=0 python setup.py install` in a Docker container on cloudtop.

In python shell, ran:

>>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8, device=xm.xla_device()), torch.tensor([1, 0, 3], dtype=torch.int8, device=xm.xla_device()))
tensor([1, 0, 3], device='xla:0', dtype=torch.int8)
>>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8, device=xm.xla_device()), torch.tensor([1, 0, 3], dtype=torch.int8, device=xm.xla_device()))
tensor([-1, -2,  3], device='xla:0', dtype=torch.int8)
>>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8, device=xm.xla_device()), torch.tensor([1, 0, 3], dtype=torch.int8, device=xm.xla_device()))
tensor([-2, -2,  0], device='xla:0', dtype=torch.int8)
>>> torch.bitwise_not(torch.tensor([-1, -2, 3], dtype=torch.int8, device=xm.xla_device()))
tensor([ 0,  1, -4], device='xla:0', dtype=torch.int8)
>>> torch.bitwise_and(torch.tensor([True, True, False], device=xm.xla_device()), torch.tensor([False, True, False], device=xm.xla_device()))
tensor([False,  True, False], device='xla:0')
>>> torch.bitwise_or(torch.tensor([True, True, False], device=xm.xla_device()), torch.tensor([False, True, False], device=xm.xla_device()))
tensor([ True,  True, False], device='xla:0')
>>> torch.bitwise_xor(torch.tensor([True, True, False], device=xm.xla_device()), torch.tensor([False, True, False], device=xm.xla_device()))
tensor([ True, False, False], device='xla:0')
>>> torch.bitwise_not(torch.tensor([True, True, False], device=xm.xla_device()))
tensor([False, False,  True], device='xla:0')

steventk-g avatar Aug 01 '22 22:08 steventk-g