
Carry encoding in the preferred storage type of a hoistable type

Open jtuyls opened this issue 2 months ago • 3 comments

This PR makes sure that the encoding attribute is included in the preferred storage type of a HoistableTensorType. Without this we can get a tensor.bitcast on an encoded type returning a non-encoded type, which will result in a compilation failure when converting into stream. I also updated the verifier of flow.tensor.bitcast to catch this issue earlier.

Note that I think there could be a potential issue with encodings on types that are bitcast: how do we guarantee that both sides will result in the same materialization? Potentially, we should add something like a storage/underlying_type field to encodings to account for this. However, in practice, I don't think we're going to encode sub-byte types any differently for now. But I can take care of that as well if desired and/or file an issue.

jtuyls avatar Oct 06 '25 20:10 jtuyls

I think this works for our current use cases, but it seems dubious to me, since the padding computation could be different for the source and result of the bitcast. For example:

%bitcast = flow.tensor.bitcast %src : tensor<16xi8, #encoding> -> tensor<32xf4E2M1FN, #encoding>

calculateStorageSizeInBytes will compute the size by padding each dim to the corresponding inner tile size. Say the inner tile size for this encoding is 32: the source would be padded from 16 to 32 elements, doubling its size, while the result (already 32 elements) would not be padded.
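As a back-of-the-envelope check, a minimal sketch of that computation (`paddedStorageBytes` is a hypothetical stand-in for `calculateStorageSizeInBytes`, assuming the tile-padding rule described above):

```cpp
#include <cstdint>
#include <cstdio>

// Round the element count up to the inner tile size, then convert bits to bytes.
int64_t paddedStorageBytes(int64_t numElements, int64_t bitWidth, int64_t tile) {
  int64_t padded = ((numElements + tile - 1) / tile) * tile;
  return (padded * bitWidth + 7) / 8;
}

int main() {
  // Source tensor<16xi8, #encoding>: 16 pads up to 32 elements * 8 bits = 32 bytes.
  std::printf("source: %lld bytes\n", (long long)paddedStorageBytes(16, 8, 32));
  // Result tensor<32xf4E2M1FN, #encoding>: already 32 elements * 4 bits = 16 bytes.
  std::printf("result: %lld bytes\n", (long long)paddedStorageBytes(32, 4, 32));
  return 0;
}
```

The two sides of the bitcast disagree on storage size (32 vs. 16 bytes), which is exactly the hazard.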

The reason it works today is that all flow.tensor.bitcast ops are effectively no-ops: they are converted into stream.tensor.clone ops, which eventually get folded. In order for this IR to be valid, either the encoding needs to change from the source to the result of the bitcast, or the encoding needs to verify that the size will be the same before/after the bitcast. That would probably need to happen through another interface function.

Yes, so what do you think about keeping track of the underlying type in the encoding? This could be an optional field specified in the encoding after the bitcast. EDIT: this 'underlying_type' would be specific to the current data-tiling encodings. An interface method like you're suggesting could take care of it in general, so encodings can specify how they should be updated on type changes.
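For concreteness, a minimal sketch of what such an interface hook could look like (the name and signature are hypothetical, not the actual IREE SerializableAttr API):

```cpp
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical hook: given the source/result types of a bitcast, an encoding
// returns the encoding to attach to the result, or a null attribute if it
// cannot be carried across this type change (making the bitcast invalid).
Attribute getEncodingForBitcast(Attribute sourceEncoding,
                                RankedTensorType sourceType,
                                RankedTensorType resultType);
```

A data-tiling encoding could implement this by recording the source element type (the 'underlying_type' above) so that later materialization tiles based on the original type.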


jtuyls avatar Oct 07 '25 15:10 jtuyls

I suspect there may be other issues besides this one that are the root cause - we should rely on encoding propagation to handle uniformly folding encodings into globals (util.global -> [util.global.load -> bitcast that has uniform encodings] -> folds to util.global with the encoding -> util.global.load).

@benvanik The HoistIntoGlobals pass will insert a new bitcast through the HoistableTensorTypeInterface, so it should have an encoding attached to it, right? What do you think about the above suggestion to add a new interface method for encodings to implement, specifying how the encoding should change on a type change?

jtuyls avatar Oct 07 '25 16:10 jtuyls

Below is the output w/o the PR. I think what Ben said makes sense: they are just bytes, and you can interpret the data with tensor types carrying an optional encoding. A global is eventually just a raw pointer plus a size; that happens when we convert to the Stream dialect.

#encoding = #iree_encoding.testing<>
module @hoist_subbyte_with_encoding {
  util.global private @__hoisted_tensor_32xi8 : tensor<32xi8>
  util.initializer {
    %cst = arith.constant dense<3> : tensor<64xi4>
    %0 = flow.tensor.encode %cst : tensor<64xi4> -> tensor<64xi4, #encoding>
    %1 = "iree_unregistered.const_expr"(%0) : (tensor<64xi4, #encoding>) -> tensor<64xi4, #encoding>
    %2 = iree_tensor_ext.bitcast %1 : tensor<64xi4, #encoding> -> tensor<32xi8>
    util.global.store %2, @__hoisted_tensor_32xi8 : tensor<32xi8>
    util.return
  }
  util.func public @main() -> tensor<64xi4, #encoding> {
    %__hoisted_tensor_32xi8 = util.global.load immutable @__hoisted_tensor_32xi8 : tensor<32xi8>
    %0 = iree_tensor_ext.bitcast %__hoisted_tensor_32xi8 : tensor<32xi8> -> tensor<64xi4, #encoding>
    util.return %0 : tensor<64xi4, #encoding>
  }
}

The solution, to me, is that we should introduce a static version of calculateStorageSizeInBytes, which only functions for static shapes. Constants and weights have static shapes; we never see dynamic constants in practice, but they could still be supported through the interface mechanism. For now, I think we can implement the static version of the interface method and use it here, i.e., the preferred storage type can be tensor<calculateStorageSizeInBytes()xi8>. The attribute itself controls whether the data is packed or not. The mental model is that we store/load plain data to/from the variable; the bitcast op before the use helps you interpret/parse the data, given the encoding.
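A minimal sketch of that static-only computation, with illustrative names (the real interface method and its tile bookkeeping live in IREE's encoding attributes, not here):

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Static-shape-only storage size: round each dim up to its inner tile size,
// multiply, and convert bits to bytes. Returns nullopt on any dynamic dim,
// mirroring the "only functions for static shapes" restriction above.
std::optional<int64_t> staticStorageSizeInBytes(
    const std::vector<int64_t> &shape,  // -1 marks a dynamic dim
    const std::vector<int64_t> &tiles,  // inner tile size per dim (1 = untiled)
    int64_t elementBitWidth) {
  int64_t elements = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] < 0)
      return std::nullopt;
    elements *= ((shape[i] + tiles[i] - 1) / tiles[i]) * tiles[i];
  }
  return (elements * elementBitWidth + 7) / 8;
}
```

The preferred storage type for the global would then be tensor<NxI8> with N the returned byte count, and the trailing bitcast reinterprets those bytes under the encoding.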

Note: MLIR has supported overloading interface methods since https://github.com/llvm/llvm-project/commit/842622bf8bea782e9d9865ed78b0d8643f098122

(I'm just back from vacation, so I may miss something. We can chat more in tomorrow's sync.)

hanhanW avatar Oct 13 '25 23:10 hanhanW

@hanhanW @benvanik I updated this PR to let encodings specify how they should be converted on a bitcast through a SerializableAttr interface method. Could you have another look?

jtuyls avatar Dec 15 '25 08:12 jtuyls

Fly-by: it looks like a subtle point from Ben is being missed here:

util.global -> [util.global.load -> bitcast that has uniform encodings]

At the point where we're introducing these bitcasts, at the global opt level, we do not have uniform encodings yet; most are unspecialized. I question whether the bitcast in the current input is well-formed at all:

%2 = iree_tensor_ext.bitcast %1 : tensor<64xi4, #encoding> -> tensor<32xi8>

Do we still need to introduce the bitcasts for encoded values? The reason we currently introduce them is to keep certain element types from ever making their way into an iree_hal_buffer_view_t, and the global/initializer boundary was [one of] the places where we couldn't move everything inside a dispatch (I wouldn't be surprised if loop-carried sub-byte values don't work for the same reason).

The reason we are able to introduce the bitcasts at the global opt level is because we can compute the size of an unencoded tensor at that point. For encoded ones I think we need to wait until/after specializing, but that should be fine because we have specialization to update values both inside and outside executables.

qedawkins avatar Dec 15 '25 15:12 qedawkins

I am not sure I am entirely following what's missing, but I added a couple of comments below.

At the point where we're introducing these bitcasts, at the global opt level, we do not have uniform encodings yet; most are unspecialized. I question whether the bitcast in the current input is well-formed at all:

%2 = iree_tensor_ext.bitcast %1 : tensor<64xi4, #encoding> -> tensor<32xi8>

This is not what we should see with this PR. Here, we ask the encoding how it should be transformed on a bitcast, and we would get something like:

%2 = iree_tensor_ext.bitcast %1 : tensor<64xi4, #encoding1> -> tensor<32xi8, #encoding2>

The reason we are able to introduce the bitcasts at the global opt level is because we can compute the size of an unencoded tensor at that point. For encoded ones I think we need to wait until/after specializing, but that should be fine because we have specialization to update values both inside and outside executables.

Are you saying that we shouldn't hoist bitcast types with encodings at all at this point, and hoist later instead? If so, I am not sure what the issue is with hoisting here with updated encodings and then specializing and materializing later.

jtuyls avatar Dec 15 '25 16:12 jtuyls

This is not what we should ever see, I think. With this PR, we ask the encoding how it should be transformed on a bitcast, and we would get something like:

As long as we aren't getting the kind of bitcasts in Hanhan's comment, I'm ok.

Are you saying that we shouldn't hoist bitcast types with encodings at all at this point, and hoist later instead? If so, I am not sure what the issue is with hoisting here with updated encodings and then specializing and materializing later.

I was thinking of just hoisting encoded values and leaving the type of the global encoded as well, relying on specialization to select the right size for it (which could itself be computed by a different initializer). I might be missing something about the feasibility of that, though.

qedawkins avatar Dec 15 '25 16:12 qedawkins

IIUC, the bitcast op is lowered to a stream.tensor.clone op, and the original failure is in the verifier. Now we have encodings on the source and the result; they are "acceptable" because they are not serialized yet. The verification is deferred until encoding specialization. Once the encodings are serialized, they are required to be compatible. In this case, they are the same, so we don't see errors.

I can help review, if it matches my understanding.

hanhanW avatar Dec 16 '25 09:12 hanhanW

IIUC, the bitcast op is lowered to a stream.tensor.clone op, and the original failure is in the verifier. Now we have encodings on the source and the result; they are "acceptable" because they are not serialized yet.

Yes, indeed.

The verification is deferred until encoding specialization. Once the encodings are serialized, they are required to be compatible.

That would be ideal, but currently I didn't add verification that a bitcast with specialized encodings on source and dest is correct. For this, I think we would need a static size calculation method or a custom verifier that can somehow prove correctness. I wasn't sure whether the static size calculation was worth it right now, and the custom verifier seemed like it would just do the reverse of the specialization logic (compare tile size A * bit width A == tile size B * bit width B).
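For reference, the check such a verifier would boil down to, assuming each specialized encoding exposes an innermost tile size and element bit width (illustrative names, not the current IREE API):

```cpp
#include <cstdint>

// A bitcast between two specialized encodings preserves storage size iff the
// per-tile bit counts match, e.g. tile 32 @ 4 bits <-> tile 16 @ 8 bits.
bool bitcastPreservesStorage(int64_t srcTileSize, int64_t srcBitWidth,
                             int64_t dstTileSize, int64_t dstBitWidth) {
  return srcTileSize * srcBitWidth == dstTileSize * dstBitWidth;
}
```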

The burden is currently on the resolver to ensure correctness.

In this case, they are the same, so we don't see errors.

Yeah, technically, the bitcasts get converted into stream.tensor.clone and get folded away.

jtuyls avatar Dec 16 '25 15:12 jtuyls