[RFC] NVFP4 Rounding Modes
The current rounding mode for NVFP4 tensors in TorchAO is round-to-nearest. The purpose of this issue is to discuss support for other rounding modes.
What rounding modes are available?
- Stochastic rounding (RS)
- Round-to-nearest (RN)
- Round-toward-zero (RZ)
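For intuition, here is a minimal illustration (plain Python, not TorchAO code) of how the three modes behave on a value that falls between two adjacent NVFP4 (e2m1) representable values, 1.0 and 1.5:

```python
import random

# Illustrative only: 1.2 lies between the adjacent e2m1-representable values 1.0 and 1.5.
x, lo, hi = 1.2, 1.0, 1.5

rn = lo if (x - lo) < (hi - x) else hi     # RN: pick the closer grid point -> 1.0 (ties-to-even omitted)
rz = lo                                    # RZ: truncate toward zero       -> 1.0
p_up = (x - lo) / (hi - lo)                # RS: round up with probability 0.4,
rs = hi if random.random() < p_up else lo  #     so the result equals x in expectation

print(rn, rz, rs)
```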
Where do we need different rounding modes?
- NVFP4 Training Recipe (https://github.com/pytorch/torchtitan/issues/1962)
- RS for gradients
- RN for weights and activations
- _AdamW in torchao.optim supports BF16 stochastic rounding: https://github.com/pytorch/ao/blob/1e473ed94caa060b3fbac96657030c44173759e0/torchao/optim/README.md?plain=1#L71-L77
- INT8 quantization in TorchAO has a stochastic rounding mode: https://github.com/pytorch/ao/blob/1e473ed94caa060b3fbac96657030c44173759e0/torchao/prototype/quantized_training/int8.py#L24-L26
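To make the pattern in those references concrete, below is a simplified illustration of the usual f32-to-bf16 stochastic-rounding trick (inject random bits into the low bits that truncation discards). This is an illustrative sketch, not the exact torchao implementation:

```python
import torch

def bf16_stochastic_round(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: f32 -> bf16 with stochastic rounding, done by injecting
    # random bits into the 16 low bits that truncation would otherwise discard.
    assert x.dtype == torch.float32
    bits = x.view(torch.int32)                    # reinterpret the float32 bit pattern
    rnd = torch.randint(0, 1 << 16, x.shape, dtype=torch.int32, device=x.device)
    bits = (bits + rnd) & -(1 << 16)              # add randomness, then clear the low 16 bits
    return bits.view(torch.float32).to(torch.bfloat16)
```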
Existing RN Kernels
- Eager path: https://github.com/pytorch/ao/blob/1e473ed94caa060b3fbac96657030c44173759e0/torchao/prototype/custom_fp_utils.py#L27-L30
- torch.compile: https://github.com/pytorch/ao/blob/1e473ed94caa060b3fbac96657030c44173759e0/torchao/prototype/mx_formats/kernels.py#L1492-L1493
  - Uses `cvt.rn.satfinite.e2m1x2.f32` inline asm
Possible RS kernel implementations
- Emulated implementation from @slayton58. I've quickly written this in Triton syntax below, but it probably makes the most sense to write it in PyTorch eager, similar to the RN path (`_f32_to_floatx_unpacked`); see the eager sketch after this list.

```python
import triton
import triton.language as tl


@triton.jit
def float_rs(x, seed, offset):
    """
    Apply stochastic rounding when casting from float32 toward NVFP4 (e2m1).

    Args:
        x: Input tensor (float32)
        seed: Random seed for the random number generator
        offset: Offset for random number generation (should be unique per element)

    Returns:
        Stochastically rounded float32 tensor
    """
    # Scale down by 2^(-125) to normalize the range
    downscale_factor = tl.math.exp2(-125.0)
    x = x * downscale_factor

    # Create a 32-bit pseudorandom value (bitcast to uint32 so the shifts below are logical)
    rnd = tl.randint(seed, offset).to(tl.uint32, bitcast=True)

    # Isolate the lower 22 bits for randomness injection
    # (float32 has 23 mantissa bits; e2m1 keeps 1, so 22 bits are discarded)
    # Process: left-shift by 10, then right-shift by 10
    rnd_shifted = (rnd << 10) >> 10

    # Reinterpret the float bits as an unsigned integer
    xb = x.to(tl.uint32, bitcast=True)

    # Inject randomness into the discarded precision bits
    yb = xb + rnd_shifted

    # Clear the lower 22 bits to perform the rounding
    yb = (yb >> 22) << 22

    # Reinterpret the integer bits back as floating point
    y = yb.to(tl.float32, bitcast=True)

    # Restore the original magnitude by scaling back up
    upscale_factor = tl.math.exp2(125.0)
    y = y * upscale_factor

    return y
```

- Use an inline asm Triton kernel using `cvt.rs.satfinite.e2m1x4.f32` for stochastic rounding, similar to the RN kernel.
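As referenced above, here is a minimal PyTorch-eager sketch of the same emulation (illustrative only; the function name and the handling of the scale factors are assumptions, and it would still need to be wired into the `_f32_to_floatx_unpacked`-style path):

```python
import torch

def float_rs_eager(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical eager-mode equivalent of the Triton sketch above: add random bits
    # to the 22 float32 mantissa bits that the e2m1 cast discards, then truncate.
    assert x.dtype == torch.float32
    x = x * 2.0 ** -125                       # same range normalization as above
    xb = x.view(torch.int32)                  # reinterpret the float bits as integers
    rnd = torch.randint(0, 1 << 22, x.shape, dtype=torch.int32, device=x.device)
    yb = (xb + rnd) & -(1 << 22)              # inject randomness, clear the low 22 bits
    return yb.view(torch.float32) * 2.0 ** 125  # restore the original magnitude
```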
Integration
- A possible integration point for the NVFP4 Training Recipe use case is to specify the rounding mode in `to_nvfp4` calls (see the usage sketch after this list).

```python
class RoundingMode(Enum):
    RN = "round_nearest"
    RS = "round_stochastic"
    RZ = "round_zero"


def to_nvfp4(
    data_hp: torch.Tensor,
    block_size: int = 16,
    per_tensor_scale: Optional[torch.Tensor] = None,
    act_per_tensor_scale: Optional[torch.Tensor] = None,
    is_swizzled_scales: bool = False,
    use_triton_kernel: bool = False,
    act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None,
    rounding_mode: RoundingMode = RoundingMode.RN,
):
    ...
    if use_triton_kernel:
        blockwise_scales, data_lp = triton_quantize_nvfp4(
            data_hp, per_tensor_scale, rounding_mode
        )
    else:
        blockwise_scales, data_lp = nvfp4_quantize(
            data_hp, block_size, per_tensor_scale, rounding_mode
        )
```

- We should discuss whether we need to support rounding modes more generically, to cover other use cases like _AdamW and INT8 training.
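As referenced above, a hypothetical call-site sketch for the training-recipe use case (variable names are illustrative and it assumes the `to_nvfp4` signature sketched above):

```python
# Weights and activations: round-to-nearest (the current default)
weight_nvfp4 = to_nvfp4(weight, rounding_mode=RoundingMode.RN)
act_nvfp4 = to_nvfp4(activation, rounding_mode=RoundingMode.RN)

# Gradients: stochastic rounding, per the NVFP4 training recipe
grad_nvfp4 = to_nvfp4(grad_output, rounding_mode=RoundingMode.RS)
```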
Test Plan
- TODO
CC: @slayton58, @ngimel, @supriyar, @Priyadlfw, @ptrblck, @eqy
cc @drisspg @vkuzo
+1 and thanks for raising this. How about
- we can make the enum you suggested torchao-wide, with default RTNE matching the dtype behavior in PyTorch
- callsites can opt in to a non-default value of this enum as needed (4-bit and lower-bit-width workflows are the highest priority)
- the implementation is not part of the BC surface, so it can vary by callsite and be shared or not shared as needed.
It could also be part of the existing act_quant_kwargs at first if we want. That's an even smaller API surface, and could be shifted out to a "real" argument later.
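A minimal sketch of that variant, reusing the `RoundingMode` enum from above and assuming `QuantizeTensorToNVFP4Kwargs` is a dataclass-style container (the `rounding_mode` field is hypothetical):

```python
from dataclasses import dataclass

@dataclass
class QuantizeTensorToNVFP4Kwargs:
    # ...existing fields elided...
    rounding_mode: RoundingMode = RoundingMode.RN  # hypothetical new field
```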
In terms of more general RS support, we should be careful to make sure we all understand exactly what RS actually does, and where it has even theoretical benefits (let alone real ones) - it's quite an invasive change, given the need for RNG, and a comparatively expensive one computationally too (again, see RNG).
It's also a pain to test :) Been there, done that, got the ~migraines~ t-shirt.
starting this inside MX / under prototype sgtm
The theoretical (and practical) benefit of SR is that it provides an unbiased gradient estimate. As for the computational expense: for bf16 it's still bandwidth-bound on H100; nvfp4 might be different since there's much less bandwidth to play with, but there are also faster acceptable ways of generating random bits.
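For reference, the unbiasedness follows directly from the definition of stochastic rounding between adjacent representable values $a < x < b$ (round up with probability proportional to the distance from $a$):

$$
\mathbb{E}[\mathrm{SR}(x)] \;=\; a \cdot \frac{b-x}{b-a} \;+\; b \cdot \frac{x-a}{b-a} \;=\; x .
$$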
linking nvfp4 training tracker (dense + moe) here as well: https://github.com/pytorch/ao/issues/3293