TransformerEngine
[Core] Fix inconsistent logic in C++ tensor class
Description
This PR fixes some hacky logic in the C++ Tensor class:
- Construct uninitialized tensors with `shape=[0]`. Previously we constructed them as 0-D tensors, which should have one entry (see https://github.com/NVIDIA/TransformerEngine/pull/2215#discussion_r2427408835).
- Make `Tensor::has_data` and `Tensor::has_columnwise_data` have a consistent meaning: "is it safe to touch the data pointer?".
- Fix code that uses `Tensor::has_data`/`Tensor::has_columnwise_data` inappropriately (`Tensor::shape`, swizzling). These are cases where the data is initialized but the pointer is not safe to touch, e.g. when the tensor has a zero dim. Note that I've only fixed the places that were causing problems; I haven't gone through the entire code base.
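To illustrate the first point, here is a minimal sketch of why a 0-D shape implies one entry while `shape=[0]` implies none. The `numel` helper is hypothetical and written only for this example; it is not TransformerEngine's API.

```cpp
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper (not TransformerEngine API): the number of entries
// implied by a shape is the product of its dimensions.
std::size_t numel(const std::vector<std::size_t>& shape) {
  // The product over an empty list of dims is 1, so a 0-D tensor holds one
  // entry. A shape containing a zero dim, such as [0], holds no entries.
  return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                         std::multiplies<std::size_t>());
}
```

Under this convention an empty shape gives 1 and `[0]` gives 0, which is why `shape=[0]` is the natural representation for a tensor with no entries.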
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or a new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [x] Code refactoring
Changes
- Construct uninitialized tensors with zero entries.
- Fix logic in `Tensor::has_data` and `Tensor::has_columnwise_data`.
- Remove inappropriate usage of `Tensor::has_data`/`Tensor::has_columnwise_data` from `Tensor::shape` and swizzling.
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
/te-ci
/te-ci L1
/te-ci
Greptile Overview
Greptile Summary
This PR refactors the C++ Tensor class to fix inconsistent logic around uninitialized tensors. The key changes are:
- Uninitialized tensors now use `shape=[0]` instead of `shape=[]` (0-D tensors should have 1 element)
- `has_data()` and `has_columnwise_data()` now delegate to `SimpleTensor::has_data()`, which checks if the buffer is non-default
- Replaced direct pointer checks with `has_data()` calls throughout the codebase
- Fixed inappropriate usage in `shape()`, swizzling, and normalization code
The refactoring improves consistency and correctness by distinguishing uninitialized tensors from tensors with zero elements. Most changes are mechanical replacements of pointer checks with has_data() calls. The swizzle.cu changes are more substantial, adding better validation and clearer control flow.
One concern: the has_data() implementation uses complex boolean logic that may not handle all edge cases correctly (e.g., tensors with shape=[1] and null pointer would incorrectly report as initialized).
Confidence Score: 4/5
- This PR is generally safe to merge with minor risk from the `has_data()` logic complexity
- Score reflects thorough refactoring with consistent changes across the codebase, but one logical issue in the `has_data()` implementation could cause edge case bugs. The vast majority of changes are mechanical and correct (pointer checks → `has_data()` calls, empty shape `{}` → `{0}`). The swizzle.cu refactoring significantly improves code clarity and validation. However, the `has_data()` boolean logic may not correctly distinguish all tensor states, particularly for edge cases like shape=[1] with null pointer.
- Pay close attention to `transformer_engine/common/common.h` for the `has_data()` logic edge case
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/common.h | 4/5 | Refactored SimpleTensor and Tensor classes to use shape=[0] for uninitialized tensors, updated has_data() logic, and simplified Tensor::shape() function. Changes improve consistency but has_data() logic may be overly complex. |
| transformer_engine/common/transformer_engine.cpp | 5/5 | Replaced direct pointer checks with has_data() calls throughout. Added checks for FP8 vs non-FP8 tensors in CheckScaleTensorShape(). Updated nvte_make_shape() to handle null pointers. Clean and consistent changes. |
| transformer_engine/common/swizzle/swizzle.cu | 4/5 | Restructured swizzle_scaling_factors() to properly use has_data() checks instead of raw pointer checks. Added better validation and clearer control flow with switch statements. Logic appears correct but is complex. |
| transformer_engine/common/include/transformer_engine/transformer_engine.h | 5/5 | Updated TensorWrapper constructor to use emptyShape instead of defaultShape for uninitialized tensors. Added logic to set scale_inv to empty when pointer is null. Consistent and safe changes. |
Sequence Diagram
sequenceDiagram
participant User
participant Tensor
participant SimpleTensor
participant has_data()
Note over Tensor,SimpleTensor: Uninitialized Tensor Construction
User->>Tensor: Tensor()
Tensor->>SimpleTensor: SimpleTensor() default ctor
SimpleTensor->>SimpleTensor: dptr = nullptr<br/>shape = [0]<br/>dtype = kFloat32
Note over Tensor,has_data(): Checking Uninitialized Tensor
User->>Tensor: has_data()
Tensor->>SimpleTensor: data.has_data()
SimpleTensor->>has_data(): Check: !(dptr == nullptr && shape.size() == 1 && shape[0] == 0)
has_data()-->>SimpleTensor: false (uninitialized)
SimpleTensor-->>Tensor: false
Tensor-->>User: false
Note over Tensor,SimpleTensor: Initialized Tensor with Valid Data
User->>Tensor: Set dptr to valid pointer
User->>Tensor: Set shape to [M, N]
User->>Tensor: has_data()
Tensor->>SimpleTensor: data.has_data()
SimpleTensor->>has_data(): Check: !(dptr == nullptr && shape.size() == 1 && shape[0] == 0)
has_data()-->>SimpleTensor: true (initialized)
SimpleTensor-->>Tensor: true
Tensor-->>User: true
Note over Tensor,SimpleTensor: Edge Case: Zero-sized Tensor with dptr
User->>Tensor: Set dptr to valid pointer
User->>Tensor: Set shape to [0] (empty tensor)
User->>Tensor: has_data()
Tensor->>SimpleTensor: data.has_data()
SimpleTensor->>has_data(): Check: dptr != nullptr
has_data()-->>SimpleTensor: true (has pointer, so "initialized")
SimpleTensor-->>Tensor: true
Tensor-->>User: true
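The check shown in the diagram can be modelled in a few lines. This is only a sketch of the predicate as the diagram states it, `!(dptr == nullptr && shape.size() == 1 && shape[0] == 0)`, and not the actual code in `transformer_engine/common/common.h`. The `SimpleTensorSketch` type is hypothetical; its field names follow the diagram.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for SimpleTensor, reduced to the two fields that
// the diagram's check inspects. Not the real TransformerEngine class.
struct SimpleTensorSketch {
  void* dptr = nullptr;
  std::vector<std::size_t> shape{0};  // default state: 1-D with zero entries

  // Predicate as stated in the diagram: returns false only when the tensor
  // is still in its default state (null pointer and shape == [0]).
  bool has_data() const {
    return !(dptr == nullptr && shape.size() == 1 && shape[0] == 0);
  }
};
```

With this predicate, a default-constructed tensor reports `false`, and any tensor with a non-null pointer reports `true`. A tensor with `shape=[1]` and a null pointer also reports `true`, which is the edge case the summary above flags.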
/te-ci L1