# PyTorch/XLA codegen: the `take` operator

This walkthrough shows what the lazy-tensor codegen produces for `aten::take`.
Shape inference PR on PyTorch: https://github.com/pytorch/pytorch/pull/82679
Generated IR node class in `LazyIr.h`:
// Lazy IR node for aten::take, as emitted by the codegen.
class Take : public XlaNode {
public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::take);
  }

  // Constructs the node from the two operand values and the precomputed
  // output shapes; the lambda recomputes the output shape on demand.
  Take(const torch::lazy::Value& self, const torch::lazy::Value& index,
       std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::take), {self, index},
                std::move(shapes),
                [&]() { return TakeOutputShape(self, index); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    // Delegates to the base node's string representation.
    return XlaNode::ToString();
  }

  // Node reuse is disabled for this op; a fresh node is built every time.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& index) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
Generated dispatcher entry in `XLANativeFunctions.cpp`:
// Dispatcher entry point for aten::take on the XLA backend: maps the ATen
// tensors to lazy XLA tensors, builds (or reuses) the Take IR node, and wraps
// the result back into an at::Tensor.
at::Tensor XLANativeFunctions::take(const at::Tensor & self, const at::Tensor & index) {
  XLA_FN_COUNTER("xla::");
  // Both inputs must resolve to the same XLA device.
  auto device = torch_xla::bridge::GetXlaDevice(self, index);
  TORCH_INTERNAL_ASSERT(device);
  torch_xla::XLATensorPtr xla_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *device);
  torch_xla::XLATensorPtr xla_index =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(index, *device);
  // Check the node cache before constructing a new IR node.
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<Take>(
      xla_self->GetIrValue(), xla_index->GetIrValue());
  if (node == nullptr) {
    auto shapes = torch::lazy::compute_shape_take(self, index);
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      // Refine the computed shapes with symbolic shape information.
      std::vector<torch::jit::IValue> inputs = {self, index};
      applySymbolicShapesOnLT(
          "aten::take(Tensor self, Tensor index) -> Tensor", inputs, shapes);
    }
    node = torch::lazy::MakeNode<Take>(xla_self->GetIrValue(),
                                       xla_index->GetIrValue(),
                                       std::move(shapes));
    CacheNode(node);
  }
  return torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *device));
}