Codegen clamp.Tensor
Fixes https://github.com/pytorch/xla/issues/3861
LazyIr.h:
class ClampTensor : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::clamp);
  }

  ClampTensor(const torch::lazy::Value& self,
              const c10::optional<torch::lazy::Value>& min,
              const c10::optional<torch::lazy::Value>& max,
              std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::clamp),
                {self, min.value_or(kNullValue), max.value_or(kNullValue)},
                std::move(shapes),
                [&]() { return ClampTensorOutputShape(self, min, max); },
                /* num_outputs */ 1, torch::lazy::MHash()) {
    has_min = !!min;
    has_max = !!max;
  }

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const c10::optional<torch::lazy::Value>& min,
                   const c10::optional<torch::lazy::Value>& max) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;

  // Absent optionals are passed as kNullValue and dropped from the operand
  // list, so these bits record which of min/max are actually present.
  bool has_min : 1;
  bool has_max : 1;
};
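For reference, here is a minimal sketch (not the actual patch) of how the hand-written Lower could consume the has_min/has_max bits. xla::Max/xla::Min, LoweringContext::GetOutputOp, and XlaNode::ReturnOp follow existing torch_xla lowering patterns; the operand-index bookkeeping for skipped optionals is an assumption:

torch_xla::XlaOpVector ClampTensor::Lower(LoweringContext* loctx) const {
  xla::XlaOp out = loctx->GetOutputOp(operand(0));  // self
  size_t next_operand = 1;
  if (has_min) {
    // Clamp from below: out = max(out, min).
    out = xla::Max(out, loctx->GetOutputOp(operand(next_operand++)));
  }
  if (has_max) {
    // Clamp from above: out = min(out, max).
    out = xla::Min(out, loctx->GetOutputOp(operand(next_operand)));
  }
  return ReturnOp(out, loctx);
}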
XLANativeFunctions.cpp:
at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self,
                                         const at::Tensor& max) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self, max);
  TORCH_INTERNAL_ASSERT(common_device);
  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_max =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(max,
                                                              *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<ClampMaxTensor>(
      lazy_self->GetIrValue(), lazy_max->GetIrValue());
  if (!node) {
    auto self_meta = to_meta(self);
    auto max_meta = to_meta(max);
    auto out_meta = at::meta::clamp_max(self_meta, max_meta);
    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, max};
      const char* schema_str =
          "aten::clamp_max.Tensor(Tensor self, Tensor max) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }
    node = torch::lazy::MakeNode<ClampMaxTensor>(
        lazy_self->GetIrValue(), lazy_max->GetIrValue(), std::move(shapes));
    CacheNode(node);
  }
  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
}
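For context on the control flow above: torch::lazy::ReuseNode first checks whether the IR cache already holds an equivalent node from the previous trace. Only on a miss is the output shape computed eagerly through the at::meta kernel (with optional symbolic-shape refinement), a fresh node built via MakeNode, and the result registered with CacheNode for reuse on the next trace.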
A couple of clamp unit tests are failing:
ERROR: test_clamp_xla_float32 (__main__.TestDevicePrecisionXLA)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 390, in instantiated_test
raise rte
File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 377, in instantiated_test
result = test(self, **param_kwargs)
File "/tmp/pytorch/xla/test/../../test/test_torch.py", line 5653, in test_clamp
self.assertEqual(expect, actual)
File "/tmp/pytorch/xla/test/pytorch_test_base.py", line 628, in assertEqual
x, y = self.prepare_for_compare(x, y)
File "/tmp/pytorch/xla/test/pytorch_test_base.py", line 575, in prepare_for_compare
y = ty.to(device='cpu')
RuntimeError: Error while lowering: [Float[100,50]] aten::clamp
Error: vector::_M_range_check: __n (which is 2) >= this->size() (which is 2)
Frames:
Looking into it.
The error was due to missing type-promotion logic for the two-tensor case in the shape-inference and lowering functions. The latest commit fixes this; all tests now pass locally.
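To illustrate the kind of fix involved, here is a sketch of two-tensor promotion in a shape-inference helper. The helper name and signature are hypothetical (the real ClampMaxTensorOutputShape takes lazy Values, not Shapes); c10::promoteTypes and at::infer_size are the standard dtype-promotion and broadcast utilities:

#include <ATen/ExpandUtils.h>      // at::infer_size
#include <c10/core/ScalarType.h>   // c10::promoteTypes

// Hypothetical helper; shows the promotion the two-tensor overload needed.
std::vector<torch::lazy::Shape> ClampMaxTensorShape(
    const torch::lazy::Shape& self, const torch::lazy::Shape& max) {
  // Promote the element type (e.g. Float + Double -> Double) and broadcast
  // the sizes, instead of blindly taking self's dtype and shape.
  c10::ScalarType out_type =
      c10::promoteTypes(self.scalar_type(), max.scalar_type());
  std::vector<int64_t> out_sizes = at::infer_size(self.sizes(), max.sizes());
  return {torch::lazy::Shape(out_type, out_sizes)};
}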
@wonjoolee95, can you fix the linter?