
[RFC] NVFP4 Rounding Modes

Open · syed-ahmed opened this issue 2 months ago · 6 comments

The current rounding mode for NVFP4 tensors in TorchAO is round-to-nearest. The purpose of this issue is to discuss support for other rounding modes.

What rounding modes are available?

  • Stochastic rounding (RS)
  • Round-to-nearest (RN)
  • Round-toward-zero (RZ)
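
For intuition, the three modes above differ only in how they pick between the two neighboring representable values. A toy sketch on a uniform grid (illustrative only; the NVFP4 grid is non-uniform and this is not the actual code path):

    import torch

    def round_to_grid(x: torch.Tensor, step: float, mode: str) -> torch.Tensor:
        """Toy illustration of RN / RZ / RS on a uniform grid of spacing `step`."""
        q = x / step
        if mode == "RN":   # round to nearest (ties to even, like torch.round)
            return torch.round(q) * step
        if mode == "RZ":   # round toward zero (truncate)
            return torch.trunc(q) * step
        if mode == "RS":   # round down/up with probability given by the fractional part
            lo = torch.floor(q)
            p_up = q - lo                                   # in [0, 1)
            up = (torch.rand_like(q) < p_up).to(q.dtype)    # 1.0 with probability p_up
            return (lo + up) * step
        raise ValueError(f"unknown mode: {mode}")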

Where do we need different rounding modes?

  • NVFP4 Training Recipe (https://github.com/pytorch/torchtitan/issues/1962)
    • RS for gradients
    • RN for weights and activations
  • _AdamW in torchao.optim supports BF16 stochastic rounding (the general bit trick is sketched after this list): https://github.com/pytorch/ao/blob/1e473ed94caa060b3fbac96657030c44173759e0/torchao/optim/README.md?plain=1#L71-L77
  • INT8 Quantization has stochastic rounding mode in TorchAO: https://github.com/pytorch/ao/blob/1e473ed94caa060b3fbac96657030c44173759e0/torchao/prototype/quantized_training/int8.py#L24-L26
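
For reference, the BF16 stochastic-rounding trick amounts to adding 16 random bits into the fp32 mantissa bits that bf16 discards and then truncating. Below is a minimal sketch of that general idea (not torchao's actual implementation; ignores inf/NaN edge cases):

    import torch

    def bf16_stochastic_round(x: torch.Tensor) -> torch.Tensor:
        """fp32 -> bf16 stochastic rounding via bit manipulation (illustrative sketch)."""
        bits = x.float().contiguous().view(torch.int32)
        # Random value in [0, 2^16) injected into the 16 bits that bf16 throws away
        noise = torch.randint(0, 1 << 16, bits.shape, dtype=torch.int32, device=x.device)
        rounded = (bits + noise) & -(1 << 16)   # zero out the lower 16 bits
        # The lower 16 bits are now zero, so the final cast to bf16 is exact
        return rounded.view(torch.float32).to(torch.bfloat16)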

Existing RN Kernels

  • Eager path: https://github.com/pytorch/ao/blob/1e473ed94caa060b3fbac96657030c44173759e0/torchao/prototype/custom_fp_utils.py#L27-L30
  • torch.compile: https://github.com/pytorch/ao/blob/1e473ed94caa060b3fbac96657030c44173759e0/torchao/prototype/mx_formats/kernels.py#L1492-L1493
    • Uses cvt.rn.satfinite.e2m1x2.f32 inline asm

Possible RS Kernel Implementations

  • Emulated implementation from @slayton58. I've quickly written this in Triton syntax below, but it probably makes the most sense to write it in PyTorch eager, similar to RN (_f32_to_floatx_unpacked). A hypothetical launcher is sketched after this list.
    @triton.jit
    def float_rs(x, seed, offset):
        """
        Apply stochastic rounding when casting from float32 to NVFP4.

        Args:
            x: Input tensor (float32)
            seed: Random seed for the random number generator
            offset: Offset for random number generation (should be unique per element)

        Returns:
            Stochastically rounded tensor (float32)
        """

        # Scale down by 2^(-125) to normalize the range
        downscale_factor = tl.math.exp2(-125.0)
        x = x * downscale_factor

        # Create a 32-bit pseudorandom value; bitcast to uint32 so the shifts
        # below are logical rather than arithmetic
        rnd = tl.randint(seed, offset).to(tl.uint32, bitcast=True)

        # Isolate the lower 22 bits for randomness injection
        # (left-shift by 10, then right-shift by 10)
        rnd_shifted = (rnd << 10) >> 10

        # Reinterpret the float bits as an unsigned integer
        xb = x.to(tl.uint32, bitcast=True)

        # Inject randomness into the precision bits that will be discarded
        yb = xb + rnd_shifted

        # Clear the lower 22 bits to perform the rounding
        yb = (yb >> 22) << 22

        # Reinterpret the integer bits back as floating point
        y = yb.to(tl.float32, bitcast=True)

        # Restore the original magnitude by scaling back up
        upscale_factor = tl.math.exp2(125.0)
        y = y * upscale_factor

        return y
    
  • Use an inline-asm Triton kernel with cvt.rs.satfinite.e2m1x4.f32 for stochastic rounding, similar to the RN path.
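
For completeness, a hypothetical Triton launcher around float_rs, showing how unique per-element RNG offsets could be threaded through (names are made up; the fp32 output would still go through the existing fp32 -> fp4 cast/pack afterwards):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _rs_round_kernel(x_ptr, out_ptr, n_elements, seed, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        # Unique offset per element -> independent random bits per element
        y = float_rs(x, seed, offsets)
        tl.store(out_ptr + offsets, y, mask=mask)

    def rs_round(x: torch.Tensor, seed: int) -> torch.Tensor:
        """Apply the emulated RS rounding to a contiguous fp32 tensor."""
        assert x.dtype == torch.float32 and x.is_contiguous()
        out = torch.empty_like(x)
        n = x.numel()
        grid = (triton.cdiv(n, 1024),)
        _rs_round_kernel[grid](x, out, n, seed, BLOCK_SIZE=1024)
        return out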

Integration

  • A possible integration point for the NVFP4 training recipe use case is to specify the rounding mode in to_nvfp4 calls (a hypothetical call-site example follows this list):
    class RoundingMode(Enum):
        RN = "round_nearest"
        RS = "round_stochastic"
        RZ = "round_zero"
    
    def to_nvfp4(
            data_hp: torch.Tensor,
            block_size: int = 16,
            per_tensor_scale: Optional[torch.Tensor] = None,
            act_per_tensor_scale: Optional[torch.Tensor] = None,
            is_swizzled_scales: bool = False,
            use_triton_kernel: bool = False,
            act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None,
            rounding_mode: RoundingMode = RoundingMode.RN,
        ):
        ...
        if use_triton_kernel:
            blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale, rounding_mode)
        else:
            blockwise_scales, data_lp = nvfp4_quantize(
                data_hp, block_size, per_tensor_scale, rounding_mode
            )
    
  • We should discuss whether we need to support rounding modes more generically to cover other use cases like _AdamW and INT8 training.
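
For the training-recipe use case, call sites could then look roughly like this (hypothetical usage of the proposed argument together with the sketch above; rounding_mode does not exist in to_nvfp4 today):

    import torch

    weight = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    grad_out = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)

    # RN for weights/activations, RS for gradients, per the recipe above
    w_nvfp4 = to_nvfp4(weight, rounding_mode=RoundingMode.RN)
    g_nvfp4 = to_nvfp4(grad_out, rounding_mode=RoundingMode.RS)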

Test Plan

  • TODO

CC: @slayton58, @ngimel, @supriyar, @Priyadlfw, @ptrblck, @eqy

syed-ahmed · Oct 30 '25

cc @drisspg @vkuzo

danielvegamyhre · Oct 30 '25

+1 and thanks for raising this. How about:

  1. We make the enum you suggested torchao-wide, with a default of RTNE matching the dtype casting behavior in PyTorch.
  2. Call sites can opt in to a non-default value of this enum as needed (4-bit-or-lower workflows are the highest priority).
  3. The implementation is not part of the BC surface, so it can vary by call site and be shared or not shared as needed.

vkuzo · Oct 31 '25

It could also be part of the existing act_quant_kwargs at first if we want. That's an even smaller API surface, and could be shifted out to a "real" argument later.
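
E.g. roughly (field name and container shape are just illustrative, assuming the kwargs object is a dataclass-style container):

    from dataclasses import dataclass

    @dataclass
    class QuantizeTensorToNVFP4Kwargs:
        # ... existing fields elided ...
        rounding_mode: RoundingMode = RoundingMode.RN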

In terms of more general RS support, we should be careful to make sure we all understand exactly what RS actually does, and where it has even theoretical benefits (let alone real ones) - it's quite an invasive change, given the need for RNG, and a comparatively expensive one computationally too (again, see RNG).

It's also a pain to test :) Been there, done that, got the ~migraines~ t-shirt.

slayton58 · Oct 31 '25

starting this inside MX / under prototype sgtm

vkuzo · Oct 31 '25

The theoretical (and practical) benefit of SR is that it provides an unbiased gradient estimate. As for the computational expense: for bf16 it's still bandwidth-bound on H100; nvfp4 might be different since there's much less bandwidth to play with, but there are also faster acceptable ways of generating random bits.
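
Concretely, for a value x between adjacent representable values a <= x <= b, SR rounds up with probability (x - a)/(b - a), so

$$\mathbb{E}[\mathrm{SR}(x)] = a\cdot\frac{b-x}{b-a} + b\cdot\frac{x-a}{b-a} = x,$$

i.e. the quantization error is zero-mean, which is what makes the gradient estimate unbiased.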

ngimel · Nov 01 '25

linking nvfp4 training tracker (dense + moe) here as well: https://github.com/pytorch/ao/issues/3293

danielvegamyhre · Nov 06 '25