moose
Macro magic for logical dialect?
The purpose of the logical dialect is to wrap all tensors in a single logical tensor type:
pub enum Tensor {
    Fixed64(Fixed64Tensor),
    Fixed128(Fixed128Tensor),
    Float32(Float32Tensor),
    Float64(Float64Tensor),
    ...
}
However, operations may only allow certain combinations of these "dtypes", say adding fixed points of the same precision.
One way of implementing this selection logic would be a single catch-all signature in hybrid_kernel!:
hybrid_kernel! {
    LogicalAddOp,
    [
        (Placement, (Tensor, Tensor) -> Tensor => Self::kernel),
    ]
}
and then doing the matching on dtypes inside Self::kernel:
impl LogicalAddOp {
    fn kernel<S: Session>(
        sess: &S,
        plc: &Placement,
        x: Tensor,
        y: Tensor,
    ) -> Tensor
    {
        // Only same-dtype fixed-point combinations are handled so far.
        match (x, y) {
            (Tensor::Fixed64(v), Tensor::Fixed64(w)) => {
                // Add the unwrapped tensors; x and y were moved into the match.
                let result = with_context!(plc, sess, v + w);
                Tensor::Fixed64(result)
            },
            (Tensor::Fixed128(v), Tensor::Fixed128(w)) => {
                let result = with_context!(plc, sess, v + w);
                Tensor::Fixed128(result)
            },
            _ => unimplemented!() // TODO
        }
    }
}
This pattern is similar to what the dispatch macros and kernels are already doing, so perhaps we could reuse that machinery, say by moving the matching into hybrid_kernel!:
hybrid_kernel! {
    LogicalAddOp,
    [
        (Placement, (Tensor::Fixed64, Tensor::Fixed64) -> Tensor::Fixed64 => Self::kernel),
        (Placement, (Tensor::Fixed128, Tensor::Fixed128) -> Tensor::Fixed128 => Self::kernel),
    ]
}
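To make the idea concrete, here is a rough standalone sketch of how a macro could expand per-variant signatures into a single match. This is not how hybrid_kernel! works today. The macro name logical_kernel!, the newtype stand-ins for the tensors, the generic add_kernel function, and the Option return type (in place of unimplemented!()) are all assumptions made for illustration, and sessions and placements are left out entirely.

```rust
use std::ops::Add;

// Stand-in tensor types (hypothetical); the real moose tensors are not modelled here.
#[derive(Debug, Clone, PartialEq)]
pub struct Fixed64Tensor(pub i64);
#[derive(Debug, Clone, PartialEq)]
pub struct Fixed128Tensor(pub i128);
#[derive(Debug, Clone, PartialEq)]
pub struct Float32Tensor(pub f32);

#[derive(Debug, Clone, PartialEq)]
pub enum Tensor {
    Fixed64(Fixed64Tensor),
    Fixed128(Fixed128Tensor),
    Float32(Float32Tensor),
}

impl Add for Fixed64Tensor {
    type Output = Fixed64Tensor;
    fn add(self, rhs: Self) -> Self::Output {
        Fixed64Tensor(self.0 + rhs.0)
    }
}

impl Add for Fixed128Tensor {
    type Output = Fixed128Tensor;
    fn add(self, rhs: Self) -> Self::Output {
        Fixed128Tensor(self.0 + rhs.0)
    }
}

// One generic kernel shared by every listed signature, playing the role of Self::kernel.
fn add_kernel<T: Add<Output = T>>(x: T, y: T) -> T {
    x + y
}

// Hypothetical macro: each `(A, B) -> Out => kernel` entry becomes one match arm
// that unwraps the inputs, calls the kernel, and re-wraps the result.
macro_rules! logical_kernel {
    ($name:ident, [$( (($a:ident, $b:ident) -> $out:ident => $f:path) ),+ $(,)?]) => {
        pub fn $name(x: Tensor, y: Tensor) -> Option<Tensor> {
            match (x, y) {
                $( (Tensor::$a(v), Tensor::$b(w)) => Some(Tensor::$out($f(v, w))), )+
                // Any combination that was not listed is rejected rather than panicking.
                _ => None,
            }
        }
    };
}

logical_kernel! {
    logical_add,
    [
        ((Fixed64, Fixed64) -> Fixed64 => add_kernel),
        ((Fixed128, Fixed128) -> Fixed128 => add_kernel),
    ]
}
```

With this, logical_add on two Fixed64 values returns the wrapped sum, while a mixed Fixed64/Fixed128 pair (or Float32, which has no listed signature) falls through to None. Whether unsupported combinations should be an error, a panic, or a compile-time failure is part of the semantics question.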
Note that it is not even clear that we want these semantics, so a discussion should be had first.
cc @voronaam