tvm
[SVE] Change the dtype of Ramp and Broadcast lanes to PrimExpr
This change will allow us to express scalable vectors through Ramp and Broadcast nodes, e.g.
vec = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale())
We will use negative values in runtime::DataType to encode the scalable lane values, e.g. the above example would result in lanes = -4. That's because the lanes in runtime::DataType are tied to the DLPack standard, which uses uint16_t for the lanes. The conversion happens in the node definitions and in runtime::DataType, so the int and uint16_t values should never be exposed to the API user, especially after the string support has been added.
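To make the encoding described above concrete, here is a minimal Python sketch. It is not the TVM implementation: it assumes the negative lane value is stored as a two's-complement 16-bit integer in the uint16_t lanes field, and the helper names `encode_lanes` and `decode_lanes` are hypothetical.

```python
import struct


def encode_lanes(lanes: int, scalable: bool = False) -> int:
    """Return the raw uint16 value for the lanes field (hypothetical helper).

    Assumption: scalable vectors store the negated vscale multiplier,
    reinterpreted as an unsigned 16-bit value.
    """
    if lanes <= 0:
        raise ValueError("lanes / vscale multiplier must be positive")
    signed = -lanes if scalable else lanes
    if not -32768 <= signed <= 32767:
        raise ValueError("value does not fit in a signed 16-bit field")
    # Reinterpret the int16 bit pattern as uint16.
    return struct.unpack("<H", struct.pack("<h", signed))[0]


def decode_lanes(raw: int) -> tuple:
    """Return (value, is_scalable) for a raw uint16 lanes field."""
    signed = struct.unpack("<h", struct.pack("<H", raw))[0]
    if signed < 0:
        return (-signed, True)  # value is the vscale multiplier
    return (signed, False)  # value is the fixed lane count


# Ramp(0, 1, 4 * vscale()) would map to lanes = -4 under this assumption.
raw = encode_lanes(4, scalable=True)
print(raw, decode_lanes(raw))
```

Under this assumption the sign bit is spent on the scalable flag, so fixed lane counts above 32767 can no longer be represented.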
Also include the TVMScript support for scalable Ramp and Broadcasts.
Note that this patch doesn't include lowering to the appropriate LLVM vectors, support for data type string representation or LoopVectorizer support. All of these will be part of future patches.
cc @neildhickey @lhutton1 @tqchen @cbalint13 @Lunderberg @Anndrey24
@tvm-bot rerun
If it is not high effort, consider adding an entry to https://github.com/apache/tvm/blob/main/python/tvm/ir/json_compact.py so that previously serialized nodes can be loaded.
Thanks @tqchen, @Lunderberg and @lhutton1 for your feedback, I uploaded a reworked version of the patch. Here's what's changed:
- Separation of vscale multiplier and fixed length vector lanes APIs in runtime::DataType - now we access these constants via vscale_factor() and lanes() methods
- is_vector() is retired now and replaced with is_scalable_vector(), is_fixed_length_vector() and is_scalable_or_fixed_length_vector()
- Refactor of the function that extracts the integer multiplier from a lanes expression as per @Lunderberg's suggestions
- Removed the ScalableLanes function. Actually, the new form of ExtractVscaleFactor does the job there, so I used that. I reckon that
    if (!arith::ExtractVscaleFactor(lanes.Eval())): ...
  reads somewhat weird, so LMK if you think it would be better to wrap it into something more self-documenting.
- Now that an attempt to fetch lanes() on a scalable vector results in an error, I decided to change the pattern in codegens
    ICHECK(!op->dtype.is_scalable()) << "Scalable vectors are not supported in codegen_c_host";
    int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
  into
    int lanes = op->dtype.lanes();
  which essentially pushes the error into runtime::DataType. This has the advantage of reducing the logic in codegens that is not really related to these codegens.
- Implemented JSON serialisation support such that graphs serialised with older versions of TVM can be correctly loaded in versions that include the changes in this patch. Unfortunately, in case of a strategic choice of lanes value, a graph serialised with an older version of TVM can be loaded as an incorrect graph without triggering an error that would then trigger the upgrade_json function, so now we have to force the json upgrade every time we try to load a serialised graph.
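For readers following along, the API split listed above can be modelled with a small Python sketch. This is a toy stand-in, not the real runtime::DataType: the class name `DataTypeModel`, the `raw_lanes` field and the rule that any negative value means "scalable" are assumptions made for illustration. Only the method names come from the list above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DataTypeModel:
    """Toy model of the lanes-related API (hypothetical, not TVM code)."""

    # Assumption: negative => scalable, magnitude is the vscale multiplier.
    raw_lanes: int

    def is_scalable_vector(self) -> bool:
        return self.raw_lanes < 0

    def is_fixed_length_vector(self) -> bool:
        return self.raw_lanes > 1

    def is_scalable_or_fixed_length_vector(self) -> bool:
        return self.is_scalable_vector() or self.is_fixed_length_vector()

    def lanes(self) -> int:
        # Fetching a fixed lane count from a scalable vector is an error,
        # so callers do not need their own explicit check.
        if self.is_scalable_vector():
            raise ValueError("cannot fetch fixed lanes() of a scalable vector")
        return self.raw_lanes

    def vscale_factor(self) -> int:
        if not self.is_scalable_vector():
            raise ValueError("vscale_factor() is only defined for scalable vectors")
        return -self.raw_lanes


fixed = DataTypeModel(4)
scalable = DataTypeModel(-4)
print(fixed.lanes(), scalable.vscale_factor())
try:
    scalable.lanes()
except ValueError as err:
    print("error:", err)
```

The design point this mirrors is the codegen change above: a caller that only handles fixed length vectors can call lanes() directly and rely on the data type to reject scalable inputs.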
Thanks @ekalda @tqchen @Lunderberg