AITemplate
Correct jagged total_length's upper bound
Summary:
If the upper bound of the `total_length` dimension is set to a value larger than B * N (N being the logical max. sequence length), the correctness of the computation is unaffected, because the actual value of `total_length` is determined at runtime and must be consistent with the `offsets`. On the other hand, memory planning happens at compile time, based on the upper bounds of all dimensions in the tensor shapes. In the aforementioned case, the memory planned for storing the jagged tensor's source (and derivative tensors) may be substantially larger than required.
This diff adds the ability to correct an over- or under-specified upper bound of the `total_length` dimension in the jagged tensor's (source) shape. The correction is based on the `batch_dim` and `jagged_dims` in the `JaggedIntVar`, which are assumed to contain the correct upper bounds for the batch and jagged dimensions (the latter being the max. sequence length), e.g., when they are inferred from the dense inputs (or additional arguments) of jagged-aware operators (like jagged elementwise).
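The correction described above can be sketched as follows. This is an illustrative standalone function, not AITemplate's actual API; all names here are hypothetical.

```python
# Hypothetical sketch of the total_length upper-bound correction.
# The function names and signature are illustrative only.

def corrected_total_length_upper_bound(
    batch_upper_bound: int,   # upper bound of batch_dim (B)
    max_seq_len: int,         # upper bound of the jagged dim (N)
    specified_upper_bound: int,
) -> int:
    """Return the tight upper bound B * N for the total_length dimension.

    The runtime value of total_length is determined by the offsets, so
    changing the compile-time upper bound does not affect correctness;
    it only changes how much memory is planned for the jagged tensor's
    source (and derivative tensors).
    """
    tight_bound = batch_upper_bound * max_seq_len
    if specified_upper_bound != tight_bound:
        # Over- or under-specified: replace with the tight bound derived
        # from the batch_dim and jagged_dims upper bounds.
        return tight_bound
    return specified_upper_bound


# Example: batch up to 32, max sequence length 128, but total_length's
# upper bound was over-specified as 100_000 -> corrected to 32 * 128.
print(corrected_total_length_upper_bound(32, 128, 100_000))  # 4096
```

With an over-specified bound of 100,000, memory planning would reserve roughly 24x more space for the source tensor than the tight bound of 4,096 requires, which is the waste this correction avoids.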
Differential Revision: D45184342
This pull request was exported from Phabricator. Differential Revision: D45184342