Codegen addcdiv and addcmul
Fix https://github.com/pytorch/xla/issues/3765, Fix https://github.com/pytorch/xla/issues/3766, Fix https://github.com/pytorch/xla/issues/3767
Example PR of codegening an op that takes an at::Scalar.
The current codegen upload does not take a scalar type, so the scalar is always uploaded with the default type (f64). To fix this, we need to identify the Value that was an at::Scalar in the original aten_xla_type.cpp and cast it to the correct dtype (a sketch of that cast is shown after the generated LazyIR below).
LazyIR
class Addcdiv : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::addcdiv);
  }

  Addcdiv(const torch::lazy::Value& self, const torch::lazy::Value& tensor1,
          const torch::lazy::Value& tensor2, const torch::lazy::Value& value,
          std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(
            torch::lazy::OpKind(at::aten::addcdiv),
            {self, tensor1, tensor2, value}, std::move(shapes),
            [&]() { return AddcdivOutputShape(self, tensor1, tensor2, value); },
            /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& tensor1,
                   const torch::lazy::Value& tensor2,
                   const torch::lazy::Value& value) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
class Addcmul : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::addcmul);
  }

  Addcmul(const torch::lazy::Value& self, const torch::lazy::Value& tensor1,
          const torch::lazy::Value& tensor2, const torch::lazy::Value& value,
          std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(
            torch::lazy::OpKind(at::aten::addcmul),
            {self, tensor1, tensor2, value}, std::move(shapes),
            [&]() { return AddcmulOutputShape(self, tensor1, tensor2, value); },
            /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& tensor1,
                   const torch::lazy::Value& tensor2,
                   const torch::lazy::Value& value) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
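The generated classes above only declare Lower; the dtype handling lives in the hand-written lowering. Below is a minimal sketch of what that cast could look like for Addcdiv, assuming the existing torch_xla helpers (loctx->GetOutputOp, XlaHelpers::TypeOfXlaOp, ReturnOp, xla::ConvertElementType) and the operand order {self, tensor1, tensor2, value}; this is illustrative rather than the exact code in this PR.

// Sketch: cast the scalar operand back to self's element type before use.
torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_self = loctx->GetOutputOp(operand(0));
  xla::XlaOp xla_tensor1 = loctx->GetOutputOp(operand(1));
  xla::XlaOp xla_tensor2 = loctx->GetOutputOp(operand(2));
  // The codegen uploaded `value` with the default type (f64); convert it to
  // the element type of `self` so the XLA graph does not mix precisions.
  xla::XlaOp xla_value = xla::ConvertElementType(
      loctx->GetOutputOp(operand(3)), XlaHelpers::TypeOfXlaOp(xla_self));
  return ReturnOp(xla_self + xla_value * (xla_tensor1 / xla_tensor2), loctx);
}

Addcmul would look the same, with tensor1 * tensor2 in place of the division.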
XLANativeFunctions
at::Tensor XLANativeFunctions::addcdiv(const at::Tensor& self,
                                       const at::Tensor& tensor1,
                                       const at::Tensor& tensor2,
                                       const at::Scalar& value) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self, tensor1, tensor2);
  TORCH_INTERNAL_ASSERT(common_device);
  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_tensor1 =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(tensor1,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_tensor2 =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(tensor2,
                                                              *common_device);
  auto node_value =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          value, *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<Addcdiv>(
      lazy_self->GetIrValue(), lazy_tensor1->GetIrValue(),
      lazy_tensor2->GetIrValue(), node_value);
  if (!node) {
    auto self_meta = to_meta(self);
    auto tensor1_meta = to_meta(tensor1);
    auto tensor2_meta = to_meta(tensor2);
    auto out_meta =
        at::meta::addcdiv(self_meta, tensor1_meta, tensor2_meta, value);
    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, tensor1, tensor2, value};
      const char* schema_str =
          "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, "
          "Scalar value=1) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }
    node = torch::lazy::MakeNode<Addcdiv>(
        lazy_self->GetIrValue(), lazy_tensor1->GetIrValue(),
        lazy_tensor2->GetIrValue(), node_value, std::move(shapes));
    CacheNode(node);
  }
  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};
at::Tensor XLANativeFunctions::addcmul(const at::Tensor& self,
                                       const at::Tensor& tensor1,
                                       const at::Tensor& tensor2,
                                       const at::Scalar& value) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self, tensor1, tensor2);
  TORCH_INTERNAL_ASSERT(common_device);
  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_tensor1 =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(tensor1,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_tensor2 =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(tensor2,
                                                              *common_device);
  auto node_value =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          value, *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<Addcmul>(
      lazy_self->GetIrValue(), lazy_tensor1->GetIrValue(),
      lazy_tensor2->GetIrValue(), node_value);
  if (!node) {
    auto self_meta = to_meta(self);
    auto tensor1_meta = to_meta(tensor1);
    auto tensor2_meta = to_meta(tensor2);
    auto out_meta =
        at::meta::addcmul(self_meta, tensor1_meta, tensor2_meta, value);
    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, tensor1, tensor2, value};
      const char* schema_str =
          "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, "
          "Scalar value=1) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }
    node = torch::lazy::MakeNode<Addcmul>(
        lazy_self->GetIrValue(), lazy_tensor1->GetIrValue(),
        lazy_tensor2->GetIrValue(), node_value, std::move(shapes));
    CacheNode(node);
  }
  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};
https://github.com/pytorch/pytorch/pull/82970#pullrequestreview-1065922141 fixed the device issue.
The issue came from trying to add an f64 and an f32 value; see the sketch after the error log below.
2 root error(s) found.
(0) INTERNAL: during context [Unknown]: Seen floating point types of different precisions in %add.122 = f64[1,16]{1,0} add(f64[1,16]{1,0} %dot.119, f32[1,16]{1,0} %broadcast.121), metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="[email protected]" source_line=114}, but mixed precision is disallowed.
[[{{node XRTCompile}}]]
[[XRTCompile_G3]]
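For illustration only, here is a small, self-contained XLA client sketch of the same kind of mismatch and of the element-type conversion that keeps the graph in one precision. The header path, function name, and values are assumptions for the sketch; the builder calls (ConstantR0/R1, ConvertElementType, Add) are standard xla_builder APIs.

#include "tensorflow/compiler/xla/client/xla_builder.h"

// Illustrative only: an f64 scalar (the default upload type) combined with
// f32 tensor data mixes precisions; converting the scalar first avoids it.
void MixedPrecisionSketch() {
  xla::XlaBuilder builder("mixed_precision_sketch");
  xla::XlaOp scalar_f64 = xla::ConstantR0<double>(&builder, 1.0);
  xla::XlaOp data_f32 = xla::ConstantR1<float>(&builder, {2.0f, 4.0f});
  // Adding data_f32 and scalar_f64 directly would build a mixed-precision add
  // like the one rejected above. Converting to the tensor's element type
  // keeps everything in f32:
  xla::XlaOp scalar_f32 = xla::ConvertElementType(scalar_f64, xla::F32);
  xla::XlaOp ok = xla::Add(data_f32, scalar_f32);
  (void)ok;
}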
The issue was that when we upload an at::Scalar to the device, the current codegen does not pass a scalar type, so we always upload with the default type (f64). To fix this, we need to identify the Value that was an at::Scalar in the original aten_xla_type.cpp and cast it to the correct dtype. FYI @wonjoolee95
@wonjoolee95 I think this one is ready for review.