
[Feature Request] Int64, Int32, Int16 constructor support multiple smaller IntX as input

Open martinvuyk opened this issue 10 months ago • 9 comments

Review Mojo's priorities

What is your request?

It would be helpful if Int64, Int32 and Int16 had constructors that accept multiple smaller SIMD values, e.g. building an Int64 from 2 Int32, 4 Int16, or 8 Int8 values, and so on.

This could be done by treating the inputs as unsigned values when Int8 values are passed (and likewise for the other types).
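To make the "treat the inputs as unsigned" point concrete, here is a minimal sketch in Python (not Mojo, and not a proposed stdlib API). The names pack_u8x8 and pack_naive are hypothetical. It shows that masking each input to 8 bits before shifting keeps a negative Int8-style value from sign-extending over the higher bytes.

```python
MASK8 = 0xFF
MASK64 = (1 << 64) - 1


def pack_u8x8(b7=0, b6=0, b5=0, b4=0, b3=0, b2=0, b1=0, b0=0):
    """Pack eight 8-bit values into one 64-bit value, b7 most significant.

    Each input is masked to 8 bits, i.e. reinterpreted as unsigned.
    """
    out = 0
    for b in (b7, b6, b5, b4, b3, b2, b1, b0):
        out = (out << 8) | (b & MASK8)
    return out


def pack_naive(hi, lo):
    """Shift-and-OR without masking: a negative low byte sign-extends."""
    return ((hi << 8) | lo) & MASK64


print(hex(pack_u8x8(b1=1, b0=-1)))  # 0x1ff
print(hex(pack_naive(1, -1)))       # 0xffffffffffffffff
```

The second result is what a signed widening followed by a shift-and-OR produces when any input byte is negative, which is why the request assumes unsigned inputs.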

What is your motivation for this change?

Casting each Int8 the normal way and then shifting to build an Int64 takes far too long.

Any other details?

On my machine, the normal cast takes about 3 times longer than using to_int() and shifting. The timing of alternative_cast varies a lot, and I imagine it depends heavily on the pipeline it ends up in at runtime, but it always takes less time than normal_cast.

from time import now


fn normal_cast(num: Int8) -> Int64:
    return num.cast[DType.int64]()


fn alternative_cast(num: Int8) -> Int64:
    return num.to_int()


fn int64_init(
    num7: Int8 = 0,
    num6: Int8 = 0,
    num5: Int8 = 0,
    num4: Int8 = 0,
    num3: Int8 = 0,
    num2: Int8 = 0,
    num1: Int8 = 0,
    num0: Int8 = 0,
) -> Int64:
    return (
        num7.to_int() << 56
        | num6.to_int() << 48
        | num5.to_int() << 40
        | num4.to_int() << 32
        | num3.to_int() << 24
        | num2.to_int() << 16
        | num1.to_int() << 8
        | num0.to_int()
    )


fn main():
    var num = Int8(5)
    var start = now()
    print(normal_cast(num))
    print("took: " + String(now() - start))
    start = now()
    print(alternative_cast(num))
    print("took: " + String(now() - start))
    start = now()
    print(int64_init(num0=num))
    print("took: " + String(now() - start))

martinvuyk avatar Apr 22 '24 15:04 martinvuyk

You can use memory.bitcast.

soraros avatar Apr 22 '24 18:04 soraros

I tried it now that you mentioned it (here num100 is an Int8):

error: 'pop.bitcast' op operand type '!pop.scalar<si8>' and result type '!pop.scalar<si64>' are cast incompatible
  var digits = bitcast[DType.int64](num100) << 8 | bitcast[DType.int64](
                                                    ^
 note: see current operation: %127 = "pop.bitcast"(%arg0) : (!pop.scalar<si8>) -> !pop.scalar<si64>

This is what I had tried before:

var p = DTypePointer[DType.int8].alloc(1)
p.store(0, 2)
var newp = p.bitcast[DType.int64]()

I tried it in a parallelized function and it doesn't work (no idea why, I just got garbage back), and using the heap for something whose size is known isn't very useful either (plus the alloc and free overhead). I also tried using Buffer[DType.int8, 1] and then bitcasting its pointer, but I got a segfault at runtime in a parallelized function.

Another option is to implement bit operations between different IntX types, since it would just need to pad the bits with 0. I tried accessing the underlying scalar (Int8().value), but it doesn't have any bit logic implemented.

martinvuyk avatar Apr 22 '24 20:04 martinvuyk

Not sure exactly what you want to achieve, but this should work. Do mind little-endian vs. big-endian, though.

from memory import bitcast

fn main():
  var a = SIMD[DType.uint8, 8](0, 1, 2, 3, 4, 5, 6, 7)
  var res = bitcast[DType.int64, 1](a)
  print(hex(res))
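As an illustration of the endianness caveat, here is a small Python sketch (not Mojo) that reinterprets the same eight bytes in both byte orders. A raw bitcast uses the machine's native order, which is little-endian on x86-64 and most ARM targets.

```python
# The same eight bytes used in the SIMD example above.
data = bytes([0, 1, 2, 3, 4, 5, 6, 7])

# Little-endian: the first byte in memory is the least significant.
little = int.from_bytes(data, "little")
# Big-endian: the first byte in memory is the most significant.
big = int.from_bytes(data, "big")

print(hex(little))  # 0x706050403020100
print(hex(big))     # 0x1020304050607
```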

soraros avatar Apr 23 '24 09:04 soraros

Sadly, my input is a List[Int8] returned by file.read_bytes, so I don't have the values concatenated (i.e. as a SIMD[DType.int8, 8]), or I just don't know how to build that type. Using var newp = bitcast[UInt64](chars) and using that as a pointer got me a segfault.

One of the places where I'm not sure if there is a performance cost:

@always_inline
fn get_64(chars: AnyPointer[Int8], startname: Int) -> UInt64:
    var n7 = chars[startname + 7].cast[DType.uint64]()
    var n6 = chars[startname + 6].cast[DType.uint64]()
    var n5 = chars[startname + 5].cast[DType.uint64]()
    var n4 = chars[startname + 4].cast[DType.uint64]()
    var n3 = chars[startname + 3].cast[DType.uint64]()
    var n2 = chars[startname + 2].cast[DType.uint64]()
    var n1 = chars[startname + 1].cast[DType.uint64]()
    var n0 = chars[startname + 0].cast[DType.uint64]()

    var b64 = n7 << (8 * 7) | n6 << (8 * 6) | n5 << (8 * 5) | n4 << (8 * 4) | n3 << (
        8 * 3
    ) | n2 << (8 * 2) | n1 << 8 | n0
    return b64

As far as I can tell, there is no way to steal the data from the list to get it into a DTypePointer[DType.int8], where I imagine bitcasting would work. Side note: couldn't big-endian and little-endian order just be parameterized in the constructor (to avoid branches)?

What I haven't found is a way to just do

var new_p = DTypePointer[DType.int8].alloc(8)
memcpy(new_p, chars.offset(startname), 8)

Removing offset(startname) (which AnyPointer doesn't implement), I get: #1 cannot be converted from 'AnyPointer[SIMD[si8, 1]]' to 'DTypePointer[si8, 0]'

It would be useful to be able to do this directly (ideally on the stack instead of with alloc, perhaps using the Buffer struct):

var new_p = DTypePointer[DType.uint64].alloc(1)
# count 8 items from src's type 
memcpy[8](new_p, chars.offset(startname))
# at compile time it can be asserted that (dest.simdtype_size / src.simdtype_size) % 2 == 0
# If the bits aren't already initialized to 0 with alloc/stack_alloc
#  var padding = dest.simdtype_size / (src.simdtype_size * count)  - 1
# ...

or some way to directly map the 8 bytes to a UInt64 without needing to index, cast, and shift each one, because I would like to be able to do bit operations on all 64 bits at the same time.
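For comparison, this is roughly what such a direct mapping looks like in Python, using the standard struct module. It is a sketch of the desired behavior, not of any Mojo API, and read_u64 is a hypothetical helper name. It reads eight consecutive bytes at an offset as one unsigned 64-bit value, with the byte order chosen by a parameter.

```python
import struct


def read_u64(buf, offset=0, little_endian=True):
    """Read 8 bytes starting at `offset` as one unsigned 64-bit integer."""
    fmt = "<Q" if little_endian else ">Q"
    return struct.unpack_from(fmt, buf, offset)[0]


# Two leading bytes to skip, followed by the eight bytes of interest.
raw = bytes([0xAA, 0xBB, 1, 0, 0, 0, 0, 0, 0, 2])

print(hex(read_u64(raw, 2)))                       # 0x200000000000001
print(hex(read_u64(raw, 2, little_endian=False)))  # 0x100000000000002
```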

martinvuyk avatar Apr 23 '24 15:04 martinvuyk

I think it's only going to be easier if you already have the pointer (from the list):

fn main():
    var l = List[UInt8](0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
    var p = DTypePointer(l.data).bitcast[DType.int64]()  # only works on nightly
    print(hex(p[0]))
    _ = l

soraros avatar Apr 23 '24 15:04 soraros

There is no DTypePointer constructor with that signature. Also, file.read_bytes returns signed integers (no idea why not UInt8).

error: no matching function in initialization
    var p = DTypePointer(l.data).bitcast[DType.int64]()
            ~~~~~~~~~~~~^~~~~~~~

martinvuyk avatar Apr 23 '24 16:04 martinvuyk

There is no DTypePointer constructor with that signature.

I'm on nightly. The following should work on release. I'd suggest looking into the stdlib source code if you can't find a certain overload.

fn main():
    var l = List[UInt8](0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
    var p = DTypePointer(Pointer(l.data.value)).bitcast[DType.int64]()
    print(hex(p[0]))
    _ = l

soraros avatar Apr 23 '24 16:04 soraros

Thanks a lot, that one helped a bunch. I was just reading the docs and trying stuff out, and didn't think to go straight to the code. But now that I've gone there, I barely understand the MLIR and LLVM stuff :'(

I imagine there is a more generic and concise way to achieve this, but here is a short demo of what I mean by the feature request:

from math.bit import bitreverse


@parameter
fn int64_init[
    little_endian: Bool, T: DType = DType.uint8
](
    n7: Int = 0,
    n6: Int = 0,
    n5: Int = 0,
    n4: Int = 0,
    n3: Int = 0,
    n2: Int = 0,
    n1: Int = 0,
    n0: Int = 0,
) -> Int64:
    if T == DType.int8 or T == DType.uint8:
        if not little_endian:
            return (
                n7 << (8 * 7)
                | n6 << (8 * 6)
                | n5 << (8 * 5)
                | n4 << (8 * 4)
                | n3 << (8 * 3)
                | n2 << (8 * 2)
                | n1 << 8
                | n0
            )
        return (
            n0 << (8 * 7)
            | n1 << (8 * 6)
            | n2 << (8 * 5)
            | n3 << (8 * 4)
            | n4 << (8 * 3)
            | n5 << (8 * 2)
            | n6 << 8
            | n7
        )
    elif T == DType.int16 or T == DType.uint16:
        if not little_endian:
            return n3 << (16 * 3) | n2 << (16 * 2) | n1 << 16 | n0
        return n0 << (16 * 3) | n1 << (16 * 2) | n2 << 16 | n3
    elif T == DType.int32 or T == DType.uint32:
        if not little_endian:
            return n1 << 32 | n0
        return n0 << 32 | n1
    return Int64(0)


@parameter
fn int64_init[
    little_endian: Bool
](
    n7: Int8 = 0,
    n6: Int8 = 0,
    n5: Int8 = 0,
    n4: Int8 = 0,
    n3: Int8 = 0,
    n2: Int8 = 0,
    n1: Int8 = 0,
    n0: Int8 = 0,
) -> Int64:
    return int64_init[little_endian](
        n7.to_int(),
        n6.to_int(),
        n5.to_int(),
        n4.to_int(),
        n3.to_int(),
        n2.to_int(),
        n1.to_int(),
        n0.to_int(),
    )


@parameter
fn int64_init[
    little_endian: Bool
](n3: Int16 = 0, n2: Int16 = 0, n1: Int16 = 0, n0: Int16 = 0,) -> Int64:
    return int64_init[little_endian, DType.int16](
        n3=n3.to_int(), n2=n2.to_int(), n1=n1.to_int(), n0=n0.to_int()
    )


@parameter
fn int64_init[
    little_endian: Bool
](n1: Int32 = 0, n0: Int32 = 0,) -> Int64:
    return int64_init[little_endian, DType.int32](n1=n1.to_int(), n0=n0.to_int())


@parameter
fn int64_init[
    little_endian: Bool = True
](nums: DTypePointer[DType.int8], offset: Int = 0) -> Int64:
    var casted = nums.offset(offset).bitcast[DType.int64]()[0]
    if not little_endian:
        return bitreverse(casted)
    return casted


@parameter
fn int64_init[little_endian: Bool = True](nums: List[Int8], offset: Int = 0) -> Int64:
    var n = DTypePointer(Pointer(nums.data.value))
    return int64_init[little_endian](n, offset)


@parameter
fn int64_init[
    little_endian: Bool = True
](nums: Tuple[Int, Int, Int, Int, Int, Int, Int, Int]) -> Int64:
    var n7 = nums.get[0, Int]()
    var n6 = nums.get[1, Int]()
    var n5 = nums.get[2, Int]()
    var n4 = nums.get[3, Int]()
    var n3 = nums.get[4, Int]()
    var n2 = nums.get[5, Int]()
    var n1 = nums.get[6, Int]()
    var n0 = nums.get[7, Int]()
    return int64_init[little_endian](n7, n6, n5, n4, n3, n2, n1, n0)


@parameter
fn int64_init[
    little_endian: Bool = True
](nums: ListLiteral[Int, Int, Int, Int, Int, Int, Int, Int]) -> Int64:
    var n7 = nums.get[0, Int]()
    var n6 = nums.get[1, Int]()
    var n5 = nums.get[2, Int]()
    var n4 = nums.get[3, Int]()
    var n3 = nums.get[4, Int]()
    var n2 = nums.get[5, Int]()
    var n1 = nums.get[6, Int]()
    var n0 = nums.get[7, Int]()
    return int64_init[little_endian](n7, n6, n5, n4, n3, n2, n1, n0)


fn main():
    var l = List[Int8](1, 0, 0, 0, 0, 0, 0, 0)
    var l2 = DTypePointer[DType.int8].alloc(8)
    l2.store(0, 1)
    for i in range(1, 8):
        l2.store(i, 0)
    print(int64_init(l))
    print(int64_init(l2))
    print(int64_init((1, 0, 0, 0, 0, 0, 0, 0)))
    print(int64_init([1, 0, 0, 0, 0, 0, 0, 0]))
    print(int64_init[little_endian=True](1))
    print(int64_init[little_endian=False](0, 0, 0, 0, 0, 0, 0, 1))
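One caveat about the pointer-based overload above: converting between little-endian and big-endian means reversing the order of the bytes, whereas a bit reversal also flips the bits inside each byte. The Python sketch below (hypothetical helper names, not Mojo) shows that the two operations give different results, so a byte swap is probably what is wanted where bitreverse is used.

```python
MASK64 = (1 << 64) - 1


def byte_swap64(x):
    """Reverse the order of the 8 bytes of a 64-bit value."""
    return int.from_bytes((x & MASK64).to_bytes(8, "little"), "big")


def bit_reverse64(x):
    """Reverse the order of all 64 bits of a 64-bit value."""
    return int(format(x & MASK64, "064b")[::-1], 2)


print(hex(byte_swap64(1)))    # 0x100000000000000
print(hex(bit_reverse64(1)))  # 0x8000000000000000
```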

martinvuyk avatar Apr 23 '24 20:04 martinvuyk

I'm sorry, I'll be more concise:

Easy and intuitive interoperability between the UIntX types would be nice.

var n0 = UInt8(0)
var n1 = UInt8(0xFF)
var n3 = UInt8(3)
var n64 = UInt64(0x303)
print(n64 * n0) # UInt64(0)
print(n64 & n3) # UInt64(3)
# creating a broadcast fn
print(n64 & n1.broadcast[DType.uint64]()) # UInt64(0x303)
# or creating a new constructor
print(n64 & UInt64(n1, n1, n1, n1, n1, n1, n1, n1))
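For reference, the broadcast idea can be sketched in plain integer arithmetic. The Python snippet below is only an illustration, and broadcast_u8 is a hypothetical name, not an existing Mojo method. Multiplying a byte by 0x0101010101010101 copies it into all eight byte lanes of a 64-bit value.

```python
MASK64 = (1 << 64) - 1
ONES = 0x0101010101010101  # a 1 in the lowest bit of each of the 8 bytes


def broadcast_u8(b):
    """Replicate one 8-bit value into every byte of a 64-bit value."""
    return ((b & 0xFF) * ONES) & MASK64


n64 = 0x303
print(hex(broadcast_u8(0xFF)))        # 0xffffffffffffffff
print(hex(n64 & broadcast_u8(0xFF)))  # 0x303
print(hex(broadcast_u8(3)))           # 0x303030303030303
```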

One question:

var n1 = UInt8(0xFF)
var n2 = UInt16(0xFFFF)
print(n2 & UInt16(n1, n1)) # 0xFFFF

This should use 16-bit registers. Are wider ones used when calling to_int() and shifting by 8? My question arises because I don't know whether the underlying scalar type is adapted to the operation itself, or whether it defaults to a 64-bit register size when interoperating with other types, given that to_int() does: return __mlir_op.`pop.cast`[_type = __mlir_type.`!pop.scalar<index>`](rebind[Scalar[type]](self).value)

Does it use more? Or is this optimized by the compiler?

Either way, calling to_int() or bitcast[...]() is way too verbose for pythonesque code. At least having the to_int() conversion done implicitly would be nice.

martinvuyk avatar Apr 25 '24 02:04 martinvuyk