
[Core] Fix inconsistent logic in C++ tensor class

Open • timmoon10 opened this pull request 1 month ago • 2 comments

Description

This PR fixes some hacky logic in the C++ Tensor class:

  • Construct uninitialized tensors with shape=[0]. Previously we constructed them as 0-D tensors, which should have one entry (see https://github.com/NVIDIA/TransformerEngine/pull/2215#discussion_r2427408835).
  • Make Tensor::has_data and Tensor::has_columnwise_data have a consistent meaning: "is it safe to touch the data pointer?".
  • Fix code that uses Tensor::has_data/Tensor::has_columnwise_data inappropriately (Tensor::shape, swizzling). These are cases where the data is initialized but the pointer is not safe to touch, e.g. when the tensor has a zero dim. Note that I've only fixed places that were causing problems; I haven't gone through the entire codebase.
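The first bullet hinges on how an element count follows from a shape. Here is a minimal sketch, assuming the usual convention that the element count is the product of the dims. The `numel` helper is illustrative only and is not TransformerEngine's API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper: element count = product of all dims.
// The product over an empty shape is 1, so a 0-D tensor (shape = []) has one
// entry, while shape = [0] has none.
std::size_t numel(const std::vector<std::size_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                         std::multiplies<std::size_t>());
}

void numel_examples() {
  assert(numel({}) == 1);      // 0-D tensor: one entry
  assert(numel({0}) == 0);     // shape = [0]: no entries
  assert(numel({3, 0}) == 0);  // any zero dim: no entries
}
```

This is why a default shape of [] misrepresents an uninitialized tensor as holding one value, whereas [0] correctly reports zero entries.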

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or new content)
  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [ ] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [ ] Infra/Build change
  • [x] Code refactoring

Changes

  • Construct uninitialized tensors with zero entries.
  • Fix logic in Tensor::has_data and Tensor::has_columnwise_data.
  • Remove inappropriate usage of Tensor::has_data/Tensor::has_columnwise_data from Tensor::shape and swizzling.
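The has_data/has_columnwise_data change can be pictured as both queries delegating to one per-buffer check. The sketch below is hypothetical: `SimpleTensorSketch` and `TensorSketch` are simplified stand-ins, not TransformerEngine's classes. The default-state condition (null pointer with shape [0]) is the one described in the Greptile review of this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a single buffer (pointer + shape).
struct SimpleTensorSketch {
  void *dptr = nullptr;
  std::vector<std::size_t> shape = {0};  // default-constructed: shape = [0]

  // False only in the default state: null pointer and shape == [0].
  bool has_data() const {
    return !(dptr == nullptr && shape.size() == 1 && shape[0] == 0);
  }
};

// Hypothetical stand-in for a tensor with row-wise and column-wise buffers.
struct TensorSketch {
  SimpleTensorSketch data;             // row-wise buffer
  SimpleTensorSketch columnwise_data;  // column-wise buffer

  // Both queries delegate to the same per-buffer check, so they share one meaning.
  bool has_data() const { return data.has_data(); }
  bool has_columnwise_data() const { return columnwise_data.has_data(); }
};

void tensor_examples() {
  TensorSketch t;
  assert(!t.has_data() && !t.has_columnwise_data());  // uninitialized

  static float buf[8];
  t.columnwise_data.dptr = buf;
  t.columnwise_data.shape = {2, 4};
  assert(!t.has_data() && t.has_columnwise_data());   // column-wise only
}
```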

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [x] I have commented my code, particularly in hard-to-understand areas
  • [x] I have made corresponding changes to the documentation
  • [ ] My changes generate no new warnings
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] New and existing unit tests pass locally with my changes

timmoon10 • Nov 01 '25

/te-ci

timmoon10 • Nov 01 '25

/te-ci L1

timmoon10 • Nov 03 '25

/te-ci

timmoon10 • Nov 18 '25

Greptile Overview

Greptile Summary

This PR refactors the C++ Tensor class to fix inconsistent logic around uninitialized tensors. The key changes are:

  • Uninitialized tensors now use shape=[0] instead of shape=[] (0-D tensors should have 1 element)
  • has_data() and has_columnwise_data() now delegate to SimpleTensor::has_data(), which checks whether the buffer is in a non-default state
  • Replaced direct pointer checks with has_data() calls throughout codebase
  • Fixed inappropriate usage in shape(), swizzling, and normalization code

The refactoring improves consistency and correctness by distinguishing uninitialized tensors from tensors with zero elements. Most changes are mechanical replacements of pointer checks with has_data() calls. The swizzle.cu changes are more substantial, adding better validation and clearer control flow.

One concern: the has_data() implementation uses complex boolean logic that may not handle all edge cases correctly (e.g., a tensor with shape=[1] and a null pointer would incorrectly be reported as initialized).
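To make the concern concrete, here is the condition from this review written as a standalone function and evaluated on a few tensor states. This is an illustrative sketch: the free function `has_data(dptr, shape)` is hypothetical and only mirrors the check shown in the sequence diagram, not the actual code in common.h.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical mirror of the reviewed check: false only when the pointer is
// null AND the shape is exactly [0] (the default-constructed state).
bool has_data(const void *dptr, const std::vector<std::size_t> &shape) {
  return !(dptr == nullptr && shape.size() == 1 && shape[0] == 0);
}

void edge_cases() {
  int dummy = 0;
  assert(!has_data(nullptr, {0}));   // default-constructed: uninitialized
  assert(has_data(&dummy, {4, 8}));  // allocated buffer
  assert(has_data(&dummy, {0}));     // empty tensor that carries a pointer
  // The flagged case: one element but no buffer still reports true.
  assert(has_data(nullptr, {1}));
}
```

Any state other than the exact default (null, [0]) reports true, so a true result does not by itself guarantee a dereferenceable pointer.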

Confidence Score: 4/5

  • This PR is generally safe to merge with minor risk from the has_data() logic complexity
  • Score reflects thorough refactoring with consistent changes across the codebase, but one logical issue in the has_data() implementation could cause edge-case bugs. The vast majority of changes are mechanical and correct (pointer checks → has_data() calls, empty shape {} → {0}). The swizzle.cu refactoring significantly improves code clarity and validation. However, the has_data() boolean logic may not correctly distinguish all tensor states, particularly for edge cases like shape=[1] with a null pointer.
  • Pay close attention to transformer_engine/common/common.h for the has_data() logic edge case

Important Files Changed

File Analysis

| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/common.h | 4/5 | Refactored SimpleTensor and Tensor classes to use shape=[0] for uninitialized tensors, updated has_data() logic, and simplified the Tensor::shape() function. Changes improve consistency, but the has_data() logic may be overly complex. |
| transformer_engine/common/transformer_engine.cpp | 5/5 | Replaced direct pointer checks with has_data() calls throughout. Added checks for FP8 vs non-FP8 tensors in CheckScaleTensorShape(). Updated nvte_make_shape() to handle null pointers. Clean and consistent changes. |
| transformer_engine/common/swizzle/swizzle.cu | 4/5 | Restructured swizzle_scaling_factors() to use has_data() checks instead of raw pointer checks. Added better validation and clearer control flow with switch statements. Logic appears correct but is complex. |
| transformer_engine/common/include/transformer_engine/transformer_engine.h | 5/5 | Updated the TensorWrapper constructor to use emptyShape instead of defaultShape for uninitialized tensors. Added logic to set scale_inv to empty when its pointer is null. Consistent and safe changes. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Tensor
    participant SimpleTensor
    participant has_data()

    Note over Tensor,SimpleTensor: Uninitialized Tensor Construction
    User->>Tensor: Tensor()
    Tensor->>SimpleTensor: SimpleTensor() default ctor
    SimpleTensor->>SimpleTensor: dptr = nullptr<br/>shape = [0]<br/>dtype = kFloat32

    Note over Tensor,has_data(): Checking Uninitialized Tensor
    User->>Tensor: has_data()
    Tensor->>SimpleTensor: data.has_data()
    SimpleTensor->>has_data(): Check: !(dptr == nullptr && shape.size() == 1 && shape[0] == 0)
    has_data()-->>SimpleTensor: false (uninitialized)
    SimpleTensor-->>Tensor: false
    Tensor-->>User: false

    Note over Tensor,SimpleTensor: Initialized Tensor with Valid Data
    User->>Tensor: Set dptr to valid pointer
    User->>Tensor: Set shape to [M, N]
    User->>Tensor: has_data()
    Tensor->>SimpleTensor: data.has_data()
    SimpleTensor->>has_data(): Check: dptr != nullptr || shape.size() != 1 || shape[0] != 0
    has_data()-->>SimpleTensor: true (initialized)
    SimpleTensor-->>Tensor: true
    Tensor-->>User: true

    Note over Tensor,SimpleTensor: Edge Case: Zero-sized Tensor with dptr
    User->>Tensor: Set dptr to valid pointer
    User->>Tensor: Set shape to [0] (empty tensor)
    User->>Tensor: has_data()
    Tensor->>SimpleTensor: data.has_data()
    SimpleTensor->>has_data(): Check: dptr != nullptr
    has_data()-->>SimpleTensor: true (has pointer, so "initialized")
    SimpleTensor-->>Tensor: true
    Tensor-->>User: true
```

greptile-apps[bot] • Nov 18 '25

/te-ci L1

timmoon10 • Nov 19 '25

/te-ci L1

timmoon10 • Nov 27 '25

/te-ci L1

timmoon10 • Dec 02 '25

/te-ci L1

timmoon10 • Dec 02 '25

/te-ci L1

timmoon10 • Dec 04 '25