[BUG] `simd.reduce_*` have inconsistent/wrong semantics when `size_out > 1`
Bug description
`reduce_max` is a reduce-splat when `size == size_out == 2`, and a no-op when `size == size_out > 2`. This is caused by the special-casing in `simd.reduce`:
https://github.com/modularml/mojo/blob/6d2a7b552769358ec68a94b8fd1f6c2126d59ad9/stdlib/src/builtin/simd.mojo#L1766-L1767
I could have opened a PR, but I don't know if it's the intended/correct behaviour (I'd argue that it is not).
In fact, I think `reduce` can be simplified.
@always_inline
fn reduce[
    func: fn[type: DType, width: Int] (
        SIMD[type, width], SIMD[type, width]
    ) capturing -> SIMD[type, width],
    size_out: Int = 1,
](self) -> SIMD[type, size_out]:
    constrained[size_out <= size]()
    @parameter
    if size == size_out:  # size == 1 covered by this case since 0 < size_out <= size
        return rebind[SIMD[type, size_out]](self)
    # size > size_out
    alias half_size: Int = size // 2  # half_size >= size_out
    var lhs = self.slice[half_size, offset=0]()
    var rhs = self.slice[half_size, offset=half_size]()
    return func[type, half_size](lhs, rhs).reduce[func, size_out]()  # size decreases, so it will terminate
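To make the semantics I have in mind concrete, here is a rough Python model of the recursion above. It is not Mojo and not the stdlib implementation: plain lists stand in for `SIMD` vectors, and the names `reduce_model` and `elementwise_max` are made up for this sketch. It assumes both sizes are powers of two, as SIMD widths are.

```python
from typing import Callable, List

Vector = List[float]


def is_power_of_two(n: int) -> bool:
    """True for 1, 2, 4, 8, ... (the widths this model accepts)."""
    return n > 0 and (n & (n - 1)) == 0


def elementwise_max(lhs: Vector, rhs: Vector) -> Vector:
    """Lane-wise max of two equally sized vectors (stand-in for the SIMD op)."""
    return [max(a, b) for a, b in zip(lhs, rhs)]


def reduce_model(
    func: Callable[[Vector, Vector], Vector],
    vec: Vector,
    size_out: int = 1,
) -> Vector:
    """Hypothetical list-based model of the simplified `reduce` proposed above."""
    size = len(vec)
    assert is_power_of_two(size) and is_power_of_two(size_out)
    assert size_out <= size

    # size == size_out: nothing left to reduce, return the input unchanged.
    if size == size_out:
        return list(vec)

    # size > size_out: combine the two halves lane-wise, then recurse.
    half_size = size // 2
    lhs, rhs = vec[:half_size], vec[half_size:]
    return reduce_model(func, func(lhs, rhs), size_out)


if __name__ == "__main__":
    b = [0.0, 1.0, 2.0, 3.0]
    print(reduce_model(elementwise_max, b, 4))
    print(reduce_model(elementwise_max, b, 2))
    print(reduce_model(elementwise_max, b, 1))
```

In this model `size == size_out` is the identity for every width, including 2, so there is no special case, and each smaller `size_out` is simply one more halving step.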
Steps to reproduce
fn main():
    alias type = DType.float32
    var a = SIMD[type, 2](0, 1)
    var b = SIMD[type, 4](0, 1, 2, 3)
    var c = SIMD[type, 8](0, 1, 2, 3, 4, 5, 6, 7)
    print(a.reduce_max[a.size]())  # [1.0, 1.0]
    print(b.reduce_max[b.size]())  # [0.0, 1.0, 2.0, 3.0]
    print(c.reduce_max[c.size]())  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
System information
Mojo 24.2 on Docker, Intel Mac