Codegen `flip` op
Fixes https://github.com/pytorch/xla/issues/3924
Codegen `flip` op.
LazyIr.h:

```cpp
class Flip : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::flip);
  }

  Flip(const torch::lazy::Value& self, const ::std::vector<int64_t>& dims,
       std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::flip),
                {self}, std::move(shapes),
                [&]() { return FlipOutputShape(self, dims); },
                /* num_outputs */ 1,
                torch::lazy::MHash(dims)),
        dims(dims) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    ss << ", dims=" << dims;
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const ::std::vector<int64_t>& dims) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;

  ::std::vector<int64_t> dims;
};
```
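The generated node keeps `dims` as a plain field and passes `torch::lazy::MHash(dims)` as the hash seed, so two `flip` nodes over the same input but with different `dims` should hash differently. The standalone sketch below is a toy model of that idea, not torch_xla code: `HashCombine`, `HashDims` and `ToyFlipNode` are hypothetical stand-ins for `torch::lazy::HashCombine`, `MHash` and the real `XlaNode`, and the mixing constant is just the common boost-style choice.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for torch::lazy::HashCombine (boost-style mixing).
inline std::size_t HashCombine(std::size_t seed, std::size_t value) {
  return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

// Hypothetical stand-in for torch::lazy::MHash(dims): fold in every element.
inline std::size_t HashDims(const std::vector<int64_t>& dims) {
  std::size_t seed = 0;
  for (int64_t d : dims) seed = HashCombine(seed, std::hash<int64_t>{}(d));
  return seed;
}

// Toy model of the generated Flip node: `dims` keeps its C++ type as a
// field and is mixed into the node hash, so different dims give a different node.
struct ToyFlipNode {
  std::size_t operand_hash;   // stands in for the hash of the `self` Value
  std::vector<int64_t> dims;  // non-tensor argument kept as a field
  std::size_t hash;

  ToyFlipNode(std::size_t operand_hash_in, std::vector<int64_t> dims_in)
      : operand_hash(operand_hash_in),
        dims(std::move(dims_in)),
        hash(HashCombine(
            HashCombine(std::hash<std::string>{}("aten::flip"), operand_hash),
            HashDims(dims))) {}

  // Mirrors the generated ToString(): op name followed by the dims field.
  std::string ToString() const {
    std::stringstream ss;
    ss << "aten::flip, dims=[";
    for (std::size_t i = 0; i < dims.size(); ++i) {
      ss << (i ? ", " : "") << dims[i];
    }
    ss << "]";
    return ss.str();
  }
};
```

Constructing `ToyFlipNode(42, {0})` twice yields equal hashes, while `ToyFlipNode(42, {1})` yields a different one; that is the property the `MHash(dims)` seed is there to provide.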
XLANativeFunctions.cpp:

```cpp
at::Tensor XLANativeFunctions::flip(const at::Tensor & self, at::IntArrayRef dims) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<Flip>(
      lazy_self->GetIrValue(), std::vector<int64_t>(dims.begin(), dims.end()));
  if (!node) {
    auto shapes = torch::lazy::compute_shape_flip(self, dims);
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = { self, dims };
      const char* schema_str = "aten::flip(Tensor self, int[] dims) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }
    node = torch::lazy::MakeNode<Flip>(
        lazy_self->GetIrValue(), std::vector<int64_t>(dims.begin(), dims.end()),
        std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};
```
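The generated `flip` follows a reuse-or-build pattern: it first asks `torch::lazy::ReuseNode<Flip>` for an existing node and, only on a miss, computes the shapes, builds the node with `MakeNode<Flip>`, and registers it with `CacheNode`. The standalone sketch below models just that control flow. `ToyNode` and `ToyNodeCache` are hypothetical, and the real lazy-tensor reuse machinery is different (note, for instance, that the generated `Flip::CanBeReused` above simply returns `false`).

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy IR node: just the operand id and the flip dims.
struct ToyNode {
  int operand_id;
  std::vector<int64_t> dims;
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

// Hypothetical cache keyed on everything that identifies a flip node.
class ToyNodeCache {
 public:
  // Analogue of ReuseNode: return a matching cached node, or nullptr.
  ToyNodePtr Reuse(int operand_id, const std::vector<int64_t>& dims) const {
    auto it = cache_.find(Key{operand_id, dims});
    if (it == cache_.end()) return nullptr;
    return it->second;
  }

  // Analogue of CacheNode: remember a freshly built node.
  void Cache(const ToyNodePtr& node) {
    cache_[Key{node->operand_id, node->dims}] = node;
  }

  // Same control flow as the generated XLANativeFunctions::flip body.
  ToyNodePtr GetOrBuildFlip(int operand_id, const std::vector<int64_t>& dims) {
    ToyNodePtr node = Reuse(operand_id, dims);
    if (!node) {
      // Shape inference (compute_shape_flip) would run here on a miss.
      node = std::make_shared<ToyNode>(ToyNode{operand_id, dims});
      ++built_;
      Cache(node);
    }
    return node;
  }

  // Number of nodes actually constructed (cache misses).
  int built() const { return built_; }

 private:
  using Key = std::pair<int, std::vector<int64_t>>;
  std::map<Key, ToyNodePtr> cache_;
  int built_ = 0;
};
```

Requesting the same flip twice returns the same shared node, while changing `dims` builds a new one.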
ehh ok
```
======================================================================
FAIL: test_flip_check_throws (__main__.TestAtenXlaTensor)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/pytorch/xla/test/test_operations.py", line 1361, in test_flip_check_throws
    self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1))
AssertionError: RuntimeError not raised by <lambda>
```
I guess we need to fix this one first.
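The failing test expects `flip` to reject repeated dimensions (`data.flip(0, 1, 1)` names dim 1 twice), and the generated path shown above builds the node without such a check. A minimal sketch of the missing validation as a standalone function follows; `CanonicalizeFlipDims` is a hypothetical name, and where the check would live in torch_xla (for example before `MakeNode<Flip>` or in the shape function) is an assumption, not something this PR settles.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: wrap negative dims into [0, rank) and reject
// duplicates, matching what test_flip_check_throws expects.
std::vector<int64_t> CanonicalizeFlipDims(const std::vector<int64_t>& dims,
                                          int64_t rank) {
  std::vector<int64_t> result;
  std::set<int64_t> seen;
  for (int64_t d : dims) {
    int64_t wrapped = d < 0 ? d + rank : d;
    if (wrapped < 0 || wrapped >= rank) {
      throw std::out_of_range("flip: dim " + std::to_string(d) + " out of range");
    }
    // insert(...).second is false when the wrapped dim was already present.
    if (!seen.insert(wrapped).second) {
      throw std::runtime_error("flip: dims has duplicates");
    }
    result.push_back(wrapped);
  }
  return result;
}
```

Because duplicates are detected after wrapping, `{2, -1}` on a rank-3 tensor is rejected just like `{1, 1}`.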
Quick question: I notice that `class Flip` has a member `::std::vector<int64_t> dims;`. For me, most of the arguments I pass to an `XlaNode` constructor are converted to type `const torch::lazy::Value`, but I'd like to keep the original type. Just wondering how you created this member.
For this PR, that was handled by the codegen automatically. `std::vector<int64_t> dims` is a non-tensor parameter of the `flip` op, so the codegen emits it as a member of `Flip`. Is there a specific op you're working with where you're seeing this issue?
Yeah, for example, I was trying to codegen `clamp`, which takes a `const c10::optional<at::Scalar>& min`. I'd like to make `min` a member of `class Clamp : public XlaNode`. It seems I have to change the codegen code to make this happen, right?
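Whether an argument becomes a node operand or a node field depends on its type in the op schema. As a rough model, assuming the torchgen lazy codegen treats `Tensor` and `Scalar` arguments (including their optional forms) as IR values and keeps other types such as `int[]` with their C++ type, this would explain why `dims` survives as `::std::vector<int64_t>` in `Flip` while `clamp`'s `Scalar? min` turns into a `torch::lazy::Value`. The toy classifier below only encodes that assumed rule over schema type strings; `ClassifySchemaArg` and `ArgKind` are hypothetical, and the real logic lives in torchgen's Python sources and handles more cases.

```cpp
#include <cassert>
#include <string>

// Hypothetical classification of a schema argument.
enum class ArgKind { kOperand, kField };

// Toy model of the assumed rule: strip an optional marker, then treat
// Tensor and Scalar as IR values (operands); everything else stays a field.
ArgKind ClassifySchemaArg(std::string type) {
  if (!type.empty() && type.back() == '?') type.pop_back();  // "Scalar?" -> "Scalar"
  if (type == "Tensor" || type == "Scalar") return ArgKind::kOperand;
  return ArgKind::kField;
}
```

Under this model, keeping `min` as a field really would require changing how the codegen maps `Scalar`, which is consistent with the suspicion above.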